suhail commited on
Commit
676582c
·
1 Parent(s): 87238f5
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +18 -0
  2. .env +16 -0
  3. .env.example +29 -4
  4. README.md +211 -8
  5. alembic/versions/20260114_1044_48b10b49730f_add_conversation_and_message_tables.py +59 -0
  6. alembic/versions/20260114_1115_37ca2e18468d_description.py +24 -0
  7. alembic/versions/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.py +27 -0
  8. alembic/versions/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.py +26 -0
  9. alembic/versions/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.py +28 -0
  10. alembic/versions/__pycache__/20260114_1044_48b10b49730f_add_conversation_and_message_tables.cpython-313.pyc +0 -0
  11. alembic/versions/__pycache__/20260114_1115_37ca2e18468d_description.cpython-313.pyc +0 -0
  12. alembic/versions/__pycache__/20260114_1659_84d7d00c71ef_add_metadata_to_message_table.cpython-313.pyc +0 -0
  13. alembic/versions/__pycache__/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.cpython-313.pyc +0 -0
  14. alembic/versions/__pycache__/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.cpython-313.pyc +0 -0
  15. alembic/versions/__pycache__/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.cpython-313.pyc +0 -0
  16. alembic/versions/tmpclaude-82e1-cwd +1 -0
  17. package-lock.json +6 -0
  18. requirements.txt +7 -1
  19. src/__pycache__/main.cpython-313.pyc +0 -0
  20. src/agent/__init__.py +0 -0
  21. src/agent/__pycache__/__init__.cpython-313.pyc +0 -0
  22. src/agent/__pycache__/agent_config.cpython-313.pyc +0 -0
  23. src/agent/__pycache__/agent_runner.cpython-313.pyc +0 -0
  24. src/agent/agent_config.py +124 -0
  25. src/agent/agent_runner.py +281 -0
  26. src/agent/providers/__init__.py +0 -0
  27. src/agent/providers/__pycache__/__init__.cpython-313.pyc +0 -0
  28. src/agent/providers/__pycache__/base.cpython-313.pyc +0 -0
  29. src/agent/providers/__pycache__/cohere.cpython-313.pyc +0 -0
  30. src/agent/providers/__pycache__/gemini.cpython-313.pyc +0 -0
  31. src/agent/providers/__pycache__/openrouter.cpython-313.pyc +0 -0
  32. src/agent/providers/base.py +105 -0
  33. src/agent/providers/cohere.py +232 -0
  34. src/agent/providers/gemini.py +285 -0
  35. src/agent/providers/openrouter.py +264 -0
  36. src/api/routes/__pycache__/chat.cpython-313.pyc +0 -0
  37. src/api/routes/__pycache__/conversations.cpython-313.pyc +0 -0
  38. src/api/routes/chat.py +281 -0
  39. src/api/routes/conversations.py +291 -0
  40. src/core/__pycache__/config.cpython-313.pyc +0 -0
  41. src/core/__pycache__/exceptions.cpython-313.pyc +0 -0
  42. src/core/__pycache__/security.cpython-313.pyc +0 -0
  43. src/core/config.py +15 -0
  44. src/core/exceptions.py +125 -0
  45. src/core/security.py +52 -1
  46. src/main.py +67 -2
  47. src/mcp/__init__.py +130 -0
  48. src/mcp/__pycache__/__init__.cpython-313.pyc +0 -0
  49. src/mcp/__pycache__/tool_registry.cpython-313.pyc +0 -0
  50. src/mcp/tool_registry.py +140 -0
.dockerignore CHANGED
@@ -43,3 +43,21 @@ alembic/versions/*.pyc
43
 
44
  # Documentation
45
  docs/_build/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # Documentation
45
  docs/_build/
46
+ # Python
47
+ __pycache__/
48
+ *.pyc
49
+ *.pyo
50
+ *.pyd
51
+
52
+ # Environment
53
+ .env
54
+
55
+ # Alembic cache
56
+ alembic/versions/__pycache__/
57
+
58
+ # Node
59
+ node_modules/
60
+ package-lock.json
61
+
62
+ # Temp / AI tools
63
+ tmpclaude-*
.env CHANGED
@@ -2,6 +2,7 @@
2
  # For local PostgreSQL: postgresql://user:password@localhost:5432/todo_db
3
  # For Neon: Use your Neon connection string from the dashboard
4
  DATABASE_URL=postgresql://neondb_owner:npg_MmFvJBHT8Y0k@ep-silent-thunder-ab0rbvrp-pooler.eu-west-2.aws.neon.tech/neondb?sslmode=require&channel_binding=require
 
5
  # Application Settings
6
  APP_NAME=Task CRUD API
7
  DEBUG=True
@@ -11,3 +12,18 @@ CORS_ORIGINS=http://localhost:3000
11
  BETTER_AUTH_SECRET=zMdW1P03wJvWJnLKzQ8YYO26vHeinqmR
12
  JWT_ALGORITHM=HS256
13
  JWT_EXPIRATION_DAYS=7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  # For local PostgreSQL: postgresql://user:password@localhost:5432/todo_db
3
  # For Neon: Use your Neon connection string from the dashboard
4
  DATABASE_URL=postgresql://neondb_owner:npg_MmFvJBHT8Y0k@ep-silent-thunder-ab0rbvrp-pooler.eu-west-2.aws.neon.tech/neondb?sslmode=require&channel_binding=require
5
+
6
  # Application Settings
7
  APP_NAME=Task CRUD API
8
  DEBUG=True
 
12
  BETTER_AUTH_SECRET=zMdW1P03wJvWJnLKzQ8YYO26vHeinqmR
13
  JWT_ALGORITHM=HS256
14
  JWT_EXPIRATION_DAYS=7
15
+
16
+ # LLM Provider Configuration
17
+ LLM_PROVIDER=gemini
18
+ # FALLBACK_PROVIDER=openrouter
19
+ GEMINI_API_KEY=AIzaSyCAlcHZxp5ELh1GqJwKqBLQziUNi0vnobU
20
+ # OPENROUTER_API_KEY=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
21
+ COHERE_API_KEY=
22
+
23
+ # Agent Configuration
24
+ AGENT_TEMPERATURE=0.7
25
+ AGENT_MAX_TOKENS=500
26
+
27
+ # Conversation Settings
28
+ CONVERSATION_MAX_MESSAGES=20
29
+ CONVERSATION_MAX_TOKENS=2000
.env.example CHANGED
@@ -6,7 +6,32 @@ APP_NAME=Task CRUD API
6
  DEBUG=True
7
  CORS_ORIGINS=http://localhost:3000
8
 
9
- # Authentication (Placeholder for Spec 2)
10
- # JWT_SECRET=your-secret-key-here
11
- # JWT_ALGORITHM=HS256
12
- # JWT_EXPIRATION_MINUTES=1440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  DEBUG=True
7
  CORS_ORIGINS=http://localhost:3000
8
 
9
+ # Authentication
10
+ BETTER_AUTH_SECRET=your-secret-key-here-min-32-characters
11
+ JWT_ALGORITHM=HS256
12
+ JWT_EXPIRATION_DAYS=7
13
+
14
+ # LLM Provider Configuration
15
+ # Primary provider: gemini, openrouter, cohere
16
+ LLM_PROVIDER=gemini
17
+
18
+ # Optional fallback provider (recommended for production)
19
+ FALLBACK_PROVIDER=openrouter
20
+
21
+ # API Keys (provide at least one for your primary provider)
22
+ GEMINI_API_KEY=your-gemini-api-key-here
23
+ OPENROUTER_API_KEY=your-openrouter-api-key-here
24
+ COHERE_API_KEY=your-cohere-api-key-here
25
+
26
+ # Agent Configuration
27
+ AGENT_TEMPERATURE=0.7
28
+ AGENT_MAX_TOKENS=8192
29
+
30
+ # Conversation Settings (for free-tier constraints)
31
+ CONVERSATION_MAX_MESSAGES=20
32
+ CONVERSATION_MAX_TOKENS=8000
33
+
34
+ # How to get API keys:
35
+ # - Gemini: https://makersuite.google.com/app/apikey (free, no credit card required)
36
+ # - OpenRouter: https://openrouter.ai/ (free models available)
37
+ # - Cohere: https://cohere.com/ (trial only, not recommended for production)
README.md CHANGED
@@ -10,24 +10,227 @@ license: mit
10
 
11
  # TaskFlow Backend API
12
 
13
- FastAPI backend for TaskFlow task management application.
14
 
15
  ## Features
16
 
17
- - User authentication with JWT
18
  - Task CRUD operations
 
19
  - PostgreSQL database with SQLModel ORM
20
  - RESTful API design
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  ## Environment Variables
23
 
24
- Configure these in your Space settings:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- - `DATABASE_URL`: PostgreSQL connection string
27
- - `SECRET_KEY`: JWT secret key (generate a secure random string)
28
- - `ALGORITHM`: JWT algorithm (default: HS256)
29
- - `ACCESS_TOKEN_EXPIRE_MINUTES`: Token expiration time (default: 30)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  ## API Documentation
32
 
33
- Once deployed, visit `/docs` for interactive API documentation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # TaskFlow Backend API
12
 
13
+ FastAPI backend for TaskFlow task management application with AI chatbot integration.
14
 
15
  ## Features
16
 
17
+ - User authentication with JWT and Better Auth
18
  - Task CRUD operations
19
+ - **AI Chatbot Assistant** - Conversational AI for task management
20
  - PostgreSQL database with SQLModel ORM
21
  - RESTful API design
22
+ - Multi-turn conversation support with context management
23
+ - Intent recognition for todo-related requests
24
+
25
+ ## Tech Stack
26
+
27
+ - **Framework**: FastAPI 0.104.1
28
+ - **ORM**: SQLModel 0.0.14
29
+ - **Database**: PostgreSQL (Neon Serverless)
30
+ - **Authentication**: Better Auth + JWT
31
+ - **AI Provider**: Google Gemini (gemini-pro)
32
+ - **Migrations**: Alembic 1.13.0
33
 
34
  ## Environment Variables
35
 
36
+ Configure these in your `.env` file:
37
+
38
+ ### Database
39
+ - `DATABASE_URL`: PostgreSQL connection string (Neon or local)
40
+
41
+ ### Application
42
+ - `APP_NAME`: Application name (default: "Task CRUD API")
43
+ - `DEBUG`: Debug mode (default: True)
44
+ - `CORS_ORIGINS`: Allowed CORS origins (default: "http://localhost:3000")
45
+
46
+ ### Authentication
47
+ - `BETTER_AUTH_SECRET`: Secret key for Better Auth (required)
48
+ - `JWT_ALGORITHM`: JWT algorithm (default: "HS256")
49
+ - `JWT_EXPIRATION_DAYS`: Token expiration in days (default: 7)
50
+
51
+ ### AI Provider Configuration
52
+ - `AI_PROVIDER`: AI provider to use (default: "gemini")
53
+ - `GEMINI_API_KEY`: Google Gemini API key (required if using Gemini)
54
+ - `OPENROUTER_API_KEY`: OpenRouter API key (optional)
55
+ - `COHERE_API_KEY`: Cohere API key (optional)
56
+
57
+ ### Conversation Settings
58
+ - `MAX_CONVERSATION_MESSAGES`: Maximum messages to keep in history (default: 20)
59
+ - `MAX_CONVERSATION_TOKENS`: Maximum tokens to keep in history (default: 8000)
60
+
61
+ ## Setup Instructions
62
+
63
+ ### 1. Install Dependencies
64
+
65
+ ```bash
66
+ pip install -r requirements.txt
67
+ ```
68
+
69
+ ### 2. Configure Environment
70
+
71
+ Create a `.env` file in the `backend/` directory:
72
 
73
+ ```env
74
+ # Database
75
+ DATABASE_URL=postgresql://user:password@localhost:5432/todo_db
76
+
77
+ # Application
78
+ APP_NAME=Task CRUD API
79
+ DEBUG=True
80
+ CORS_ORIGINS=http://localhost:3000
81
+
82
+ # Authentication
83
+ BETTER_AUTH_SECRET=your_secret_key_here
84
+ JWT_ALGORITHM=HS256
85
+ JWT_EXPIRATION_DAYS=7
86
+
87
+ # AI Provider
88
+ AI_PROVIDER=gemini
89
+ GEMINI_API_KEY=your_gemini_api_key_here
90
+
91
+ # Conversation Settings
92
+ MAX_CONVERSATION_MESSAGES=20
93
+ MAX_CONVERSATION_TOKENS=8000
94
+ ```
95
+
96
+ ### 3. Run Database Migrations
97
+
98
+ ```bash
99
+ alembic upgrade head
100
+ ```
101
+
102
+ ### 4. Start the Server
103
+
104
+ ```bash
105
+ uvicorn src.main:app --reload --host 0.0.0.0 --port 8000
106
+ ```
107
+
108
+ The API will be available at `http://localhost:8000`
109
 
110
  ## API Documentation
111
 
112
+ Once running, visit:
113
+ - **Interactive Docs**: `http://localhost:8000/docs`
114
+ - **ReDoc**: `http://localhost:8000/redoc`
115
+
116
+ ## API Endpoints
117
+
118
+ ### Authentication
119
+ - `POST /api/auth/signup` - Register new user
120
+ - `POST /api/auth/login` - Login user
121
+
122
+ ### Tasks
123
+ - `GET /api/{user_id}/tasks` - Get all tasks for user
124
+ - `POST /api/{user_id}/tasks` - Create new task
125
+ - `GET /api/{user_id}/tasks/{task_id}` - Get specific task
126
+ - `PUT /api/{user_id}/tasks/{task_id}` - Update task
127
+ - `DELETE /api/{user_id}/tasks/{task_id}` - Delete task
128
+
129
+ ### AI Chat (New in Phase 1)
130
+ - `POST /api/{user_id}/chat` - Send message to AI assistant
131
+
132
+ #### Chat Request Body
133
+ ```json
134
+ {
135
+ "message": "Can you help me organize my tasks?",
136
+ "conversation_id": 123, // Optional: null for new conversation
137
+ "temperature": 0.7 // Optional: 0.0 to 1.0
138
+ }
139
+ ```
140
+
141
+ #### Chat Response
142
+ ```json
143
+ {
144
+ "conversation_id": 123,
145
+ "message": "I'd be happy to help you organize your tasks!",
146
+ "role": "assistant",
147
+ "timestamp": "2026-01-14T10:30:00Z",
148
+ "token_count": 25,
149
+ "model": "gemini-pro"
150
+ }
151
+ ```
152
+
153
+ ## AI Chatbot Features
154
+
155
+ ### Phase 1 (Current)
156
+ - ✅ Natural conversation with AI assistant
157
+ - ✅ Multi-turn conversations with context retention
158
+ - ✅ Intent recognition for todo-related requests
159
+ - ✅ Conversation history persistence
160
+ - ✅ Automatic history trimming (20 messages / 8000 tokens)
161
+ - ✅ Free-tier AI provider support (Gemini)
162
+
163
+ ### Phase 2 (Coming Soon)
164
+ - 🔄 MCP tools for task CRUD operations
165
+ - 🔄 AI can directly create, update, and delete tasks
166
+ - 🔄 Natural language task management
167
+
168
+ ## Error Handling
169
+
170
+ The API returns standard HTTP status codes:
171
+
172
+ - `200 OK` - Request successful
173
+ - `400 Bad Request` - Invalid request data
174
+ - `401 Unauthorized` - Authentication required or failed
175
+ - `404 Not Found` - Resource not found
176
+ - `429 Too Many Requests` - Rate limit exceeded
177
+ - `500 Internal Server Error` - Server error
178
+
179
+ ## Database Schema
180
+
181
+ ### Users Table
182
+ - `id`: Primary key
183
+ - `email`: Unique email address
184
+ - `name`: User's name
185
+ - `password`: Hashed password
186
+ - `created_at`, `updated_at`: Timestamps
187
+
188
+ ### Tasks Table
189
+ - `id`: Primary key
190
+ - `user_id`: Foreign key to users
191
+ - `title`: Task title
192
+ - `description`: Task description
193
+ - `completed`: Boolean status
194
+ - `created_at`, `updated_at`: Timestamps
195
+
196
+ ### Conversation Table (New)
197
+ - `id`: Primary key
198
+ - `user_id`: Foreign key to users
199
+ - `title`: Conversation title
200
+ - `created_at`, `updated_at`: Timestamps
201
+
202
+ ### Message Table (New)
203
+ - `id`: Primary key
204
+ - `conversation_id`: Foreign key to conversation
205
+ - `role`: "user" or "assistant"
206
+ - `content`: Message text
207
+ - `timestamp`: Message timestamp
208
+ - `token_count`: Token count for the message
209
+
210
+ ## Development
211
+
212
+ ### Running Tests
213
+ ```bash
214
+ pytest
215
+ ```
216
+
217
+ ### Database Migrations
218
+
219
+ Create a new migration:
220
+ ```bash
221
+ alembic revision -m "description"
222
+ ```
223
+
224
+ Apply migrations:
225
+ ```bash
226
+ alembic upgrade head
227
+ ```
228
+
229
+ Rollback migration:
230
+ ```bash
231
+ alembic downgrade -1
232
+ ```
233
+
234
+ ## License
235
+
236
+ MIT License
alembic/versions/20260114_1044_48b10b49730f_add_conversation_and_message_tables.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """add_conversation_and_message_tables
2
+
3
+ Revision ID: 48b10b49730f
4
+ Revises: 002_add_user_password
5
+ Create Date: 2026-01-14 10:44:27.010796
6
+
7
+ """
8
+ from alembic import op
9
+ import sqlalchemy as sa
10
+
11
+
12
+ # revision identifiers, used by Alembic.
13
+ revision = '48b10b49730f'
14
+ down_revision = '002_add_user_password'
15
+ branch_labels = None
16
+ depends_on = None
17
+
18
+
19
+ def upgrade() -> None:
20
+ # Create conversation table
21
+ op.create_table(
22
+ 'conversation',
23
+ sa.Column('id', sa.Integer(), nullable=False),
24
+ sa.Column('user_id', sa.Integer(), nullable=False),
25
+ sa.Column('title', sa.String(length=255), nullable=True),
26
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
27
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
28
+ sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
29
+ sa.PrimaryKeyConstraint('id')
30
+ )
31
+ op.create_index('ix_conversation_user_id', 'conversation', ['user_id'])
32
+ op.create_index('ix_conversation_created_at', 'conversation', ['created_at'])
33
+
34
+ # Create message table
35
+ op.create_table(
36
+ 'message',
37
+ sa.Column('id', sa.Integer(), nullable=False),
38
+ sa.Column('conversation_id', sa.Integer(), nullable=False),
39
+ sa.Column('role', sa.String(length=50), nullable=False),
40
+ sa.Column('content', sa.Text(), nullable=False),
41
+ sa.Column('timestamp', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
42
+ sa.Column('token_count', sa.Integer(), nullable=True),
43
+ sa.ForeignKeyConstraint(['conversation_id'], ['conversation.id'], ondelete='CASCADE'),
44
+ sa.PrimaryKeyConstraint('id')
45
+ )
46
+ op.create_index('ix_message_conversation_id', 'message', ['conversation_id'])
47
+ op.create_index('ix_message_timestamp', 'message', ['timestamp'])
48
+
49
+
50
+ def downgrade() -> None:
51
+ # Drop message table first (due to foreign key dependency)
52
+ op.drop_index('ix_message_timestamp', table_name='message')
53
+ op.drop_index('ix_message_conversation_id', table_name='message')
54
+ op.drop_table('message')
55
+
56
+ # Drop conversation table
57
+ op.drop_index('ix_conversation_created_at', table_name='conversation')
58
+ op.drop_index('ix_conversation_user_id', table_name='conversation')
59
+ op.drop_table('conversation')
alembic/versions/20260114_1115_37ca2e18468d_description.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """description
2
+
3
+ Revision ID: 37ca2e18468d
4
+ Revises: 48b10b49730f
5
+ Create Date: 2026-01-14 11:15:03.691055
6
+
7
+ """
8
+ from alembic import op
9
+ import sqlalchemy as sa
10
+
11
+
12
+ # revision identifiers, used by Alembic.
13
+ revision = '37ca2e18468d'
14
+ down_revision = '48b10b49730f'
15
+ branch_labels = None
16
+ depends_on = None
17
+
18
+
19
+ def upgrade() -> None:
20
+ pass
21
+
22
+
23
+ def downgrade() -> None:
24
+ pass
alembic/versions/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """add metadata column to message table
2
+
3
+ Revision ID: a3c44bf7ddcb
4
+ Revises: 37ca2e18468d
5
+ Create Date: 2026-01-14 17:02:51.060200
6
+
7
+ """
8
+ from alembic import op
9
+ import sqlalchemy as sa
10
+ from sqlalchemy.dialects import postgresql
11
+
12
+
13
+ # revision identifiers, used by Alembic.
14
+ revision = 'a3c44bf7ddcb'
15
+ down_revision = '37ca2e18468d'
16
+ branch_labels = None
17
+ depends_on = None
18
+
19
+
20
+ def upgrade() -> None:
21
+ # Add metadata column to message table
22
+ op.add_column('message', sa.Column('metadata', postgresql.JSON(astext_type=sa.Text()), nullable=True))
23
+
24
+
25
+ def downgrade() -> None:
26
+ # Remove metadata column from message table
27
+ op.drop_column('message', 'metadata')
alembic/versions/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """rename metadata to tool_metadata in message table
2
+
3
+ Revision ID: e8275e6c143c
4
+ Revises: a3c44bf7ddcb
5
+ Create Date: 2026-01-14 17:12:53.740315
6
+
7
+ """
8
+ from alembic import op
9
+ import sqlalchemy as sa
10
+
11
+
12
+ # revision identifiers, used by Alembic.
13
+ revision = 'e8275e6c143c'
14
+ down_revision = 'a3c44bf7ddcb'
15
+ branch_labels = None
16
+ depends_on = None
17
+
18
+
19
+ def upgrade() -> None:
20
+ # Rename metadata column to tool_metadata
21
+ op.alter_column('message', 'metadata', new_column_name='tool_metadata')
22
+
23
+
24
+ def downgrade() -> None:
25
+ # Rename tool_metadata column back to metadata
26
+ op.alter_column('message', 'tool_metadata', new_column_name='metadata')
alembic/versions/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """add due_date and priority to task table
2
+
3
+ Revision ID: d34db62bd406
4
+ Revises: e8275e6c143c
5
+ Create Date: 2026-01-14 19:00:45.426280
6
+
7
+ """
8
+ from alembic import op
9
+ import sqlalchemy as sa
10
+ from sqlalchemy.dialects import postgresql
11
+
12
+ # revision identifiers, used by Alembic.
13
+ revision = 'd34db62bd406'
14
+ down_revision = 'e8275e6c143c'
15
+ branch_labels = None
16
+ depends_on = None
17
+
18
+
19
+ def upgrade() -> None:
20
+ # Add due_date and priority columns to tasks table
21
+ op.add_column('tasks', sa.Column('due_date', sa.Date(), nullable=True))
22
+ op.add_column('tasks', sa.Column('priority', sa.String(length=20), nullable=False, server_default='medium'))
23
+
24
+
25
+ def downgrade() -> None:
26
+ # Remove due_date and priority columns from tasks table
27
+ op.drop_column('tasks', 'priority')
28
+ op.drop_column('tasks', 'due_date')
alembic/versions/__pycache__/20260114_1044_48b10b49730f_add_conversation_and_message_tables.cpython-313.pyc ADDED
Binary file (3.65 kB). View file
 
alembic/versions/__pycache__/20260114_1115_37ca2e18468d_description.cpython-313.pyc ADDED
Binary file (771 Bytes). View file
 
alembic/versions/__pycache__/20260114_1659_84d7d00c71ef_add_metadata_to_message_table.cpython-313.pyc ADDED
Binary file (4.64 kB). View file
 
alembic/versions/__pycache__/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.cpython-313.pyc ADDED
Binary file (1.25 kB). View file
 
alembic/versions/__pycache__/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.cpython-313.pyc ADDED
Binary file (1.04 kB). View file
 
alembic/versions/__pycache__/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.cpython-313.pyc ADDED
Binary file (1.48 kB). View file
 
alembic/versions/tmpclaude-82e1-cwd ADDED
@@ -0,0 +1 @@
 
 
1
+ /d/Agentic_ai_learning/hacathoon_2/evolution-of-todo/phase-2-full-stack-web-app/backend/alembic/versions
package-lock.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "name": "backend",
3
+ "lockfileVersion": 3,
4
+ "requires": true,
5
+ "packages": {}
6
+ }
requirements.txt CHANGED
@@ -13,4 +13,10 @@ passlib[bcrypt]==1.7.4
13
  python-multipart==0.0.6
14
 
15
  # Add this line (or replace if bcrypt is already listed)
16
- bcrypt==4.3.0 # Last stable version before the 5.0 break
 
 
 
 
 
 
 
13
  python-multipart==0.0.6
14
 
15
  # Add this line (or replace if bcrypt is already listed)
16
+ bcrypt==4.3.0 # Last stable version before the 5.0 break
17
+
18
+ # AI Chatbot dependencies
19
+ google-generativeai==0.3.2 # Gemini API client
20
+ tiktoken==0.5.2 # Token counting (optional)
21
+ mcp==1.20.0 # Official MCP SDK for tool server
22
+ cohere==4.37 # Cohere API client (optional provider)
src/__pycache__/main.cpython-313.pyc CHANGED
Binary files a/src/__pycache__/main.cpython-313.pyc and b/src/__pycache__/main.cpython-313.pyc differ
 
src/agent/__init__.py ADDED
File without changes
src/agent/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (206 Bytes). View file
 
src/agent/__pycache__/agent_config.cpython-313.pyc ADDED
Binary file (4.41 kB). View file
 
src/agent/__pycache__/agent_runner.cpython-313.pyc ADDED
Binary file (10.1 kB). View file
 
src/agent/agent_config.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent Configuration
3
+
4
+ Configuration dataclass for agent behavior and LLM provider settings.
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Optional
9
+
10
+
11
+ @dataclass
12
+ class AgentConfiguration:
13
+ """
14
+ Configuration for the AI agent.
15
+
16
+ This dataclass defines all configurable parameters for agent behavior,
17
+ including LLM provider settings, conversation limits, and system prompts.
18
+ """
19
+
20
+ # Provider settings
21
+ provider: str = "gemini" # Options: gemini, openrouter, cohere
22
+ fallback_provider: Optional[str] = None # Optional fallback provider
23
+ model: Optional[str] = None # Model name (provider-specific)
24
+
25
+ # API keys (loaded from environment)
26
+ gemini_api_key: Optional[str] = None
27
+ openrouter_api_key: Optional[str] = None
28
+ cohere_api_key: Optional[str] = None
29
+
30
+ # Generation parameters
31
+ temperature: float = 0.7 # Sampling temperature (0.0 to 1.0)
32
+ max_tokens: int = 8192 # Maximum tokens in response
33
+
34
+ # Conversation history limits (for free-tier constraints)
35
+ max_messages: int = 20 # Maximum messages to keep in history
36
+ max_conversation_tokens: int = 8000 # Maximum tokens in conversation history
37
+
38
+ # System prompt
39
+ system_prompt: str = """You are a helpful AI assistant for managing tasks.
40
+ You can help users create, view, complete, update, and delete tasks using natural language.
41
+
42
+ Available tools:
43
+ - add_task: Create a new task
44
+ - list_tasks: View all tasks (with optional filtering)
45
+ - complete_task: Mark a task as completed
46
+ - delete_task: Remove a task
47
+ - update_task: Modify task properties
48
+
49
+ Always respond in a friendly, conversational manner and confirm actions taken."""
50
+
51
+ # Retry settings
52
+ max_retries: int = 3 # Maximum retries on rate limit errors
53
+ retry_delay: float = 1.0 # Delay between retries (seconds)
54
+
55
+ def get_provider_api_key(self, provider_name: str) -> Optional[str]:
56
+ """
57
+ Get API key for a specific provider.
58
+
59
+ Args:
60
+ provider_name: Provider name (gemini, openrouter, cohere)
61
+
62
+ Returns:
63
+ API key or None if not configured
64
+ """
65
+ if provider_name == "gemini":
66
+ return self.gemini_api_key
67
+ elif provider_name == "openrouter":
68
+ return self.openrouter_api_key
69
+ elif provider_name == "cohere":
70
+ return self.cohere_api_key
71
+ return None
72
+
73
+ def get_provider_model(self, provider_name: str) -> str:
74
+ """
75
+ Get default model for a specific provider.
76
+
77
+ Args:
78
+ provider_name: Provider name
79
+
80
+ Returns:
81
+ Model identifier
82
+ """
83
+ if self.model:
84
+ return self.model
85
+
86
+ # Default models per provider
87
+ defaults = {
88
+ "gemini": "gemini-flash-latest",
89
+ "openrouter": "google/gemini-flash-1.5",
90
+ "cohere": "command-r-plus"
91
+ }
92
+ return defaults.get(provider_name, "gemini-flash-latest")
93
+
94
+ def validate(self) -> bool:
95
+ """
96
+ Validate configuration.
97
+
98
+ Returns:
99
+ True if configuration is valid
100
+
101
+ Raises:
102
+ ValueError: If configuration is invalid
103
+ """
104
+ # Check primary provider has API key
105
+ primary_key = self.get_provider_api_key(self.provider)
106
+ if not primary_key:
107
+ raise ValueError(f"API key not configured for primary provider: {self.provider}")
108
+
109
+ # Validate temperature range
110
+ if not 0.0 <= self.temperature <= 1.0:
111
+ raise ValueError(f"Temperature must be between 0.0 and 1.0, got: {self.temperature}")
112
+
113
+ # Validate max_tokens
114
+ if self.max_tokens <= 0:
115
+ raise ValueError(f"max_tokens must be positive, got: {self.max_tokens}")
116
+
117
+ # Validate conversation limits
118
+ if self.max_messages <= 0:
119
+ raise ValueError(f"max_messages must be positive, got: {self.max_messages}")
120
+
121
+ if self.max_conversation_tokens <= 0:
122
+ raise ValueError(f"max_conversation_tokens must be positive, got: {self.max_conversation_tokens}")
123
+
124
+ return True
src/agent/agent_runner.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent Runner
3
+
4
+ Core orchestrator for AI agent execution with tool calling support.
5
+ Manages the full request cycle: LLM generation → tool execution → final response.
6
+ """
7
+
8
+ import logging
9
+ from typing import List, Dict, Any, Optional
10
+ import asyncio
11
+
12
+ from .agent_config import AgentConfiguration
13
+ from .providers.base import LLMProvider
14
+ from .providers.gemini import GeminiProvider
15
+ from .providers.openrouter import OpenRouterProvider
16
+ from .providers.cohere import CohereProvider
17
+ from ..mcp.tool_registry import MCPToolRegistry, ToolExecutionResult
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class AgentRunner:
23
+ """
24
+ Agent execution orchestrator with tool calling support.
25
+
26
+ This class manages the full agent request cycle:
27
+ 1. Generate LLM response with tool definitions
28
+ 2. If tool calls requested, execute tools with user context injection
29
+ 3. Generate final response with tool results
30
+ 4. Handle rate limiting with fallback providers
31
+ """
32
+
33
+ def __init__(self, config: AgentConfiguration, tool_registry: MCPToolRegistry):
34
+ """
35
+ Initialize the agent runner.
36
+
37
+ Args:
38
+ config: Agent configuration
39
+ tool_registry: MCP tool registry
40
+ """
41
+ self.config = config
42
+ self.tool_registry = tool_registry
43
+ self.primary_provider = self._create_provider(config.provider)
44
+ self.fallback_provider = None
45
+
46
+ if config.fallback_provider:
47
+ self.fallback_provider = self._create_provider(config.fallback_provider)
48
+
49
+ logger.info(f"Initialized AgentRunner with provider: {config.provider}")
50
+
51
+ def _create_provider(self, provider_name: str) -> LLMProvider:
52
+ """
53
+ Create an LLM provider instance.
54
+
55
+ Args:
56
+ provider_name: Provider name (gemini, openrouter, cohere)
57
+
58
+ Returns:
59
+ LLMProvider instance
60
+
61
+ Raises:
62
+ ValueError: If provider is not supported or API key is missing
63
+ """
64
+ api_key = self.config.get_provider_api_key(provider_name)
65
+ if not api_key:
66
+ raise ValueError(f"API key not configured for provider: {provider_name}")
67
+
68
+ model = self.config.get_provider_model(provider_name)
69
+
70
+ if provider_name == "gemini":
71
+ return GeminiProvider(
72
+ api_key=api_key,
73
+ model=model,
74
+ temperature=self.config.temperature,
75
+ max_tokens=self.config.max_tokens
76
+ )
77
+ elif provider_name == "openrouter":
78
+ return OpenRouterProvider(
79
+ api_key=api_key,
80
+ model=model,
81
+ temperature=self.config.temperature,
82
+ max_tokens=self.config.max_tokens
83
+ )
84
+ elif provider_name == "cohere":
85
+ return CohereProvider(
86
+ api_key=api_key,
87
+ model=model,
88
+ temperature=self.config.temperature,
89
+ max_tokens=self.config.max_tokens
90
+ )
91
+ else:
92
+ raise ValueError(f"Unsupported provider: {provider_name}")
93
+
94
+ async def execute(
95
+ self,
96
+ messages: List[Dict[str, str]],
97
+ user_id: int,
98
+ system_prompt: Optional[str] = None
99
+ ) -> Dict[str, Any]:
100
+ """
101
+ Execute agent request with tool calling support.
102
+
103
+ SECURITY: user_id is injected by backend, never from LLM output.
104
+
105
+ Args:
106
+ messages: Conversation history [{"role": "user", "content": "..."}]
107
+ user_id: User ID (injected by backend for security)
108
+ system_prompt: Optional system prompt (uses config default if not provided)
109
+
110
+ Returns:
111
+ Dict with response content and metadata
112
+ """
113
+ prompt = system_prompt or self.config.system_prompt
114
+ provider = self.primary_provider
115
+
116
+ try:
117
+ # Get tool definitions
118
+ tool_definitions = self.tool_registry.get_tool_definitions()
119
+
120
+ logger.info(f"Executing agent for user {user_id} with {len(tool_definitions)} tools")
121
+
122
+ # Generate initial response with tool definitions
123
+ response = await provider.generate_response_with_tools(
124
+ messages=messages,
125
+ system_prompt=prompt,
126
+ tools=tool_definitions
127
+ )
128
+
129
+ # Check if tool calls were requested
130
+ if response.tool_calls:
131
+ logger.info(f"Agent requested {len(response.tool_calls)} tool calls")
132
+
133
+ # Execute all tool calls
134
+ tool_results = []
135
+ for tool_call in response.tool_calls:
136
+ result = await self.tool_registry.execute_tool(
137
+ tool_name=tool_call["name"],
138
+ arguments=tool_call["arguments"],
139
+ user_id=user_id # Inject user context for security
140
+ )
141
+ tool_results.append(result)
142
+
143
+ # Generate final response with tool results
144
+ final_response = await provider.generate_response_with_tool_results(
145
+ messages=messages,
146
+ tool_calls=response.tool_calls,
147
+ tool_results=tool_results
148
+ )
149
+
150
+ return {
151
+ "content": final_response.content,
152
+ "tool_calls": response.tool_calls,
153
+ "tool_results": tool_results,
154
+ "provider": provider.get_provider_name()
155
+ }
156
+
157
+ # No tool calls, return direct response
158
+ logger.info("Agent generated direct response (no tool calls)")
159
+ return {
160
+ "content": response.content,
161
+ "tool_calls": None,
162
+ "tool_results": None,
163
+ "provider": provider.get_provider_name()
164
+ }
165
+
166
+ except Exception as e:
167
+ logger.error(f"Agent execution failed with primary provider: {str(e)}")
168
+
169
+ # Try fallback provider if configured
170
+ if self.fallback_provider:
171
+ logger.info("Attempting fallback provider")
172
+ try:
173
+ return await self._execute_with_provider(
174
+ provider=self.fallback_provider,
175
+ messages=messages,
176
+ user_id=user_id,
177
+ system_prompt=prompt
178
+ )
179
+ except Exception as fallback_error:
180
+ logger.error(f"Fallback provider also failed: {str(fallback_error)}")
181
+ raise
182
+
183
+ raise
184
+
185
+ async def _execute_with_provider(
186
+ self,
187
+ provider: LLMProvider,
188
+ messages: List[Dict[str, str]],
189
+ user_id: int,
190
+ system_prompt: str
191
+ ) -> Dict[str, Any]:
192
+ """
193
+ Execute agent request with a specific provider.
194
+
195
+ Args:
196
+ provider: LLM provider to use
197
+ messages: Conversation history
198
+ user_id: User ID
199
+ system_prompt: System prompt
200
+
201
+ Returns:
202
+ Dict with response content and metadata
203
+ """
204
+ tool_definitions = self.tool_registry.get_tool_definitions()
205
+
206
+ # Generate initial response
207
+ response = await provider.generate_response_with_tools(
208
+ messages=messages,
209
+ system_prompt=system_prompt,
210
+ tools=tool_definitions
211
+ )
212
+
213
+ # Handle tool calls
214
+ if response.tool_calls:
215
+ tool_results = []
216
+ for tool_call in response.tool_calls:
217
+ result = await self.tool_registry.execute_tool(
218
+ tool_name=tool_call["name"],
219
+ arguments=tool_call["arguments"],
220
+ user_id=user_id
221
+ )
222
+ tool_results.append(result)
223
+
224
+ final_response = await provider.generate_response_with_tool_results(
225
+ messages=messages,
226
+ tool_calls=response.tool_calls,
227
+ tool_results=tool_results
228
+ )
229
+
230
+ return {
231
+ "content": final_response.content,
232
+ "tool_calls": response.tool_calls,
233
+ "tool_results": tool_results,
234
+ "provider": provider.get_provider_name()
235
+ }
236
+
237
+ return {
238
+ "content": response.content,
239
+ "tool_calls": None,
240
+ "tool_results": None,
241
+ "provider": provider.get_provider_name()
242
+ }
243
+
244
+ async def execute_simple(
245
+ self,
246
+ messages: List[Dict[str, str]],
247
+ system_prompt: Optional[str] = None
248
+ ) -> str:
249
+ """
250
+ Execute a simple agent request without tool calling.
251
+
252
+ Args:
253
+ messages: Conversation history
254
+ system_prompt: Optional system prompt
255
+
256
+ Returns:
257
+ Response content as string
258
+ """
259
+ prompt = system_prompt or self.config.system_prompt
260
+ provider = self.primary_provider
261
+
262
+ try:
263
+ response = await provider.generate_simple_response(
264
+ messages=messages,
265
+ system_prompt=prompt
266
+ )
267
+ return response.content or ""
268
+
269
+ except Exception as e:
270
+ logger.error(f"Simple execution failed: {str(e)}")
271
+
272
+ # Try fallback provider
273
+ if self.fallback_provider:
274
+ logger.info("Attempting fallback provider for simple execution")
275
+ response = await self.fallback_provider.generate_simple_response(
276
+ messages=messages,
277
+ system_prompt=prompt
278
+ )
279
+ return response.content or ""
280
+
281
+ raise
src/agent/providers/__init__.py ADDED
File without changes
src/agent/providers/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (216 Bytes). View file
 
src/agent/providers/__pycache__/base.cpython-313.pyc ADDED
Binary file (4.34 kB). View file
 
src/agent/providers/__pycache__/cohere.cpython-313.pyc ADDED
Binary file (8.05 kB). View file
 
src/agent/providers/__pycache__/gemini.cpython-313.pyc ADDED
Binary file (10.5 kB). View file
 
src/agent/providers/__pycache__/openrouter.cpython-313.pyc ADDED
Binary file (10.7 kB). View file
 
src/agent/providers/base.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Provider Base Class
3
+
4
+ Abstract base class for LLM provider implementations.
5
+ Defines the interface for generating responses with function calling support.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import List, Dict, Any, Optional
10
+ from dataclasses import dataclass
11
+
12
+
13
+ @dataclass
14
+ class LLMResponse:
15
+ """Response from an LLM provider."""
16
+ content: Optional[str] = None
17
+ tool_calls: Optional[List[Dict[str, Any]]] = None
18
+ finish_reason: Optional[str] = None
19
+ usage: Optional[Dict[str, int]] = None
20
+
21
+
22
+ class LLMProvider(ABC):
23
+ """
24
+ Abstract base class for LLM providers.
25
+
26
+ All provider implementations (Gemini, OpenRouter, Cohere) must
27
+ implement these methods to support function calling and tool execution.
28
+ """
29
+
30
+ def __init__(self, api_key: str, model: str, temperature: float = 0.7, max_tokens: int = 8192):
31
+ """
32
+ Initialize the LLM provider.
33
+
34
+ Args:
35
+ api_key: API key for the provider
36
+ model: Model identifier (e.g., "gemini-1.5-flash")
37
+ temperature: Sampling temperature (0.0 to 1.0)
38
+ max_tokens: Maximum tokens in response
39
+ """
40
+ self.api_key = api_key
41
+ self.model = model
42
+ self.temperature = temperature
43
+ self.max_tokens = max_tokens
44
+
45
+ @abstractmethod
46
+ async def generate_response_with_tools(
47
+ self,
48
+ messages: List[Dict[str, str]],
49
+ system_prompt: str,
50
+ tools: List[Dict[str, Any]]
51
+ ) -> LLMResponse:
52
+ """
53
+ Generate a response with function calling support.
54
+
55
+ Args:
56
+ messages: Conversation history [{"role": "user", "content": "..."}]
57
+ system_prompt: System instructions for the agent
58
+ tools: Tool definitions for function calling
59
+
60
+ Returns:
61
+ LLMResponse with content and/or tool_calls
62
+ """
63
+ pass
64
+
65
+ @abstractmethod
66
+ async def generate_response_with_tool_results(
67
+ self,
68
+ messages: List[Dict[str, str]],
69
+ tool_calls: List[Dict[str, Any]],
70
+ tool_results: List[Dict[str, Any]]
71
+ ) -> LLMResponse:
72
+ """
73
+ Generate a final response after tool execution.
74
+
75
+ Args:
76
+ messages: Original conversation history
77
+ tool_calls: Tool calls that were made
78
+ tool_results: Results from tool execution
79
+
80
+ Returns:
81
+ LLMResponse with final content
82
+ """
83
+ pass
84
+
85
+ @abstractmethod
86
+ async def generate_simple_response(
87
+ self,
88
+ messages: List[Dict[str, str]],
89
+ system_prompt: str
90
+ ) -> LLMResponse:
91
+ """
92
+ Generate a simple response without function calling.
93
+
94
+ Args:
95
+ messages: Conversation history
96
+ system_prompt: System instructions
97
+
98
+ Returns:
99
+ LLMResponse with content
100
+ """
101
+ pass
102
+
103
+ def get_provider_name(self) -> str:
104
+ """Get the provider name (e.g., 'gemini', 'openrouter', 'cohere')."""
105
+ return self.__class__.__name__.replace("Provider", "").lower()
src/agent/providers/cohere.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cohere Provider Implementation
3
+
4
+ Cohere API provider with function calling support.
5
+ Optional provider (trial only, not recommended for production).
6
+ """
7
+
8
+ import logging
9
+ from typing import List, Dict, Any
10
+ import cohere
11
+
12
+ from .base import LLMProvider, LLMResponse
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class CohereProvider(LLMProvider):
18
+ """
19
+ Cohere API provider implementation.
20
+
21
+ Features:
22
+ - Native function calling support
23
+ - Trial tier only (not recommended for production)
24
+ - Model: command-r-plus (best for function calling)
25
+
26
+ Note: Cohere requires a paid plan after trial expires.
27
+ Use Gemini or OpenRouter for free-tier operation.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ api_key: str,
33
+ model: str = "command-r-plus",
34
+ temperature: float = 0.7,
35
+ max_tokens: int = 8192
36
+ ):
37
+ super().__init__(api_key, model, temperature, max_tokens)
38
+ self.client = cohere.Client(api_key)
39
+ logger.info(f"Initialized CohereProvider with model: {model}")
40
+
41
+ def _convert_tools_to_cohere_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
42
+ """
43
+ Convert MCP tool definitions to Cohere tool format.
44
+
45
+ Args:
46
+ tools: MCP tool definitions
47
+
48
+ Returns:
49
+ List of Cohere-formatted tool definitions
50
+ """
51
+ return [
52
+ {
53
+ "name": tool["name"],
54
+ "description": tool["description"],
55
+ "parameter_definitions": tool["parameters"].get("properties", {})
56
+ }
57
+ for tool in tools
58
+ ]
59
+
60
+ async def generate_response_with_tools(
61
+ self,
62
+ messages: List[Dict[str, str]],
63
+ system_prompt: str,
64
+ tools: List[Dict[str, Any]]
65
+ ) -> LLMResponse:
66
+ """
67
+ Generate a response with function calling support.
68
+
69
+ Args:
70
+ messages: Conversation history
71
+ system_prompt: System instructions
72
+ tools: Tool definitions
73
+
74
+ Returns:
75
+ LLMResponse with content and/or tool_calls
76
+ """
77
+ try:
78
+ # Convert tools to Cohere format
79
+ cohere_tools = self._convert_tools_to_cohere_format(tools)
80
+
81
+ # Format chat history for Cohere
82
+ chat_history = []
83
+ for msg in messages[:-1]: # All except last message
84
+ chat_history.append({
85
+ "role": "USER" if msg["role"] == "user" else "CHATBOT",
86
+ "message": msg["content"]
87
+ })
88
+
89
+ # Last message is the current user message
90
+ current_message = messages[-1]["content"] if messages else ""
91
+
92
+ # Generate response with function calling
93
+ response = self.client.chat(
94
+ message=current_message,
95
+ chat_history=chat_history,
96
+ preamble=system_prompt,
97
+ model=self.model,
98
+ temperature=self.temperature,
99
+ max_tokens=self.max_tokens,
100
+ tools=cohere_tools
101
+ )
102
+
103
+ # Check for tool calls
104
+ if response.tool_calls:
105
+ tool_calls = [
106
+ {
107
+ "name": tc.name,
108
+ "arguments": tc.parameters
109
+ }
110
+ for tc in response.tool_calls
111
+ ]
112
+ logger.info(f"Cohere requested function calls: {[tc['name'] for tc in tool_calls]}")
113
+ return LLMResponse(
114
+ content=None,
115
+ tool_calls=tool_calls,
116
+ finish_reason="tool_calls"
117
+ )
118
+
119
+ # Regular text response
120
+ content = response.text
121
+ logger.info("Cohere generated text response")
122
+ return LLMResponse(
123
+ content=content,
124
+ finish_reason="COMPLETE"
125
+ )
126
+
127
+ except Exception as e:
128
+ logger.error(f"Cohere API error: {str(e)}")
129
+ raise
130
+
131
+ async def generate_response_with_tool_results(
132
+ self,
133
+ messages: List[Dict[str, str]],
134
+ tool_calls: List[Dict[str, Any]],
135
+ tool_results: List[Dict[str, Any]]
136
+ ) -> LLMResponse:
137
+ """
138
+ Generate a final response after tool execution.
139
+
140
+ Args:
141
+ messages: Original conversation history
142
+ tool_calls: Tool calls that were made
143
+ tool_results: Results from tool execution
144
+
145
+ Returns:
146
+ LLMResponse with final content
147
+ """
148
+ try:
149
+ # Format chat history
150
+ chat_history = []
151
+ for msg in messages:
152
+ chat_history.append({
153
+ "role": "USER" if msg["role"] == "user" else "CHATBOT",
154
+ "message": msg["content"]
155
+ })
156
+
157
+ # Format tool results for Cohere
158
+ tool_results_formatted = [
159
+ {
160
+ "call": {"name": call["name"], "parameters": call["arguments"]},
161
+ "outputs": [{"result": str(result)}]
162
+ }
163
+ for call, result in zip(tool_calls, tool_results)
164
+ ]
165
+
166
+ # Generate final response
167
+ response = self.client.chat(
168
+ message="Based on the tool results, provide a natural language response.",
169
+ chat_history=chat_history,
170
+ model=self.model,
171
+ temperature=self.temperature,
172
+ max_tokens=self.max_tokens,
173
+ tool_results=tool_results_formatted
174
+ )
175
+
176
+ content = response.text
177
+ logger.info("Cohere generated final response after tool execution")
178
+ return LLMResponse(
179
+ content=content,
180
+ finish_reason="COMPLETE"
181
+ )
182
+
183
+ except Exception as e:
184
+ logger.error(f"Cohere API error in tool results: {str(e)}")
185
+ raise
186
+
187
+ async def generate_simple_response(
188
+ self,
189
+ messages: List[Dict[str, str]],
190
+ system_prompt: str
191
+ ) -> LLMResponse:
192
+ """
193
+ Generate a simple response without function calling.
194
+
195
+ Args:
196
+ messages: Conversation history
197
+ system_prompt: System instructions
198
+
199
+ Returns:
200
+ LLMResponse with content
201
+ """
202
+ try:
203
+ # Format chat history
204
+ chat_history = []
205
+ for msg in messages[:-1]:
206
+ chat_history.append({
207
+ "role": "USER" if msg["role"] == "user" else "CHATBOT",
208
+ "message": msg["content"]
209
+ })
210
+
211
+ current_message = messages[-1]["content"] if messages else ""
212
+
213
+ # Generate response
214
+ response = self.client.chat(
215
+ message=current_message,
216
+ chat_history=chat_history,
217
+ preamble=system_prompt,
218
+ model=self.model,
219
+ temperature=self.temperature,
220
+ max_tokens=self.max_tokens
221
+ )
222
+
223
+ content = response.text
224
+ logger.info("Cohere generated simple response")
225
+ return LLMResponse(
226
+ content=content,
227
+ finish_reason="COMPLETE"
228
+ )
229
+
230
+ except Exception as e:
231
+ logger.error(f"Cohere API error: {str(e)}")
232
+ raise
src/agent/providers/gemini.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gemini Provider Implementation
3
+
4
+ Google Gemini API provider with function calling support.
5
+ Primary provider for free-tier operation (15 RPM, 1M token context).
6
+ """
7
+
8
+ import logging
9
+ from typing import List, Dict, Any
10
+ import google.generativeai as genai
11
+ from google.generativeai.types import FunctionDeclaration, Tool
12
+
13
+ from .base import LLMProvider, LLMResponse
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class GeminiProvider(LLMProvider):
19
+ """
20
+ Google Gemini API provider implementation.
21
+
22
+ Features:
23
+ - Native function calling support
24
+ - 1M token context window
25
+ - Free tier: 15 requests/minute
26
+ - Model: gemini-1.5-flash (recommended for free tier)
27
+ """
28
+
29
+ def __init__(self, api_key: str, model: str = "gemini-flash-latest", temperature: float = 0.7, max_tokens: int = 8192):
30
+ super().__init__(api_key, model, temperature, max_tokens)
31
+ genai.configure(api_key=api_key)
32
+ self.client = genai.GenerativeModel(model)
33
+ logger.info(f"Initialized GeminiProvider with model: {model}")
34
+
35
+ def _sanitize_schema_for_gemini(self, schema: Dict[str, Any]) -> Dict[str, Any]:
36
+ """
37
+ Sanitize JSON Schema to be Gemini-compatible.
38
+
39
+ Gemini only supports a subset of JSON Schema keywords:
40
+ - Supported: type, description, enum, required, properties, items
41
+ - NOT supported: maxLength, minLength, pattern, format, minimum, maximum, default, etc.
42
+
43
+ Args:
44
+ schema: Original JSON Schema
45
+
46
+ Returns:
47
+ Gemini-compatible schema with unsupported fields removed
48
+ """
49
+ # Fields that Gemini supports
50
+ ALLOWED_FIELDS = {
51
+ "type", "description", "enum", "required",
52
+ "properties", "items"
53
+ }
54
+
55
+ # Create a sanitized copy
56
+ sanitized = {}
57
+
58
+ for key, value in schema.items():
59
+ if key in ALLOWED_FIELDS:
60
+ # Recursively sanitize nested objects
61
+ if key == "properties" and isinstance(value, dict):
62
+ sanitized[key] = {
63
+ prop_name: self._sanitize_schema_for_gemini(prop_schema)
64
+ for prop_name, prop_schema in value.items()
65
+ }
66
+ elif key == "items" and isinstance(value, dict):
67
+ sanitized[key] = self._sanitize_schema_for_gemini(value)
68
+ else:
69
+ sanitized[key] = value
70
+
71
+ return sanitized
72
+
73
+ def _convert_tools_to_gemini_format(self, tools: List[Dict[str, Any]]) -> List[Tool]:
74
+ """
75
+ Convert MCP tool definitions to Gemini function declarations.
76
+
77
+ Sanitizes schemas to remove unsupported JSON Schema keywords.
78
+
79
+ Args:
80
+ tools: MCP tool definitions
81
+
82
+ Returns:
83
+ List of Gemini Tool objects
84
+ """
85
+ function_declarations = []
86
+ for tool in tools:
87
+ # Sanitize parameters to remove unsupported fields
88
+ sanitized_parameters = self._sanitize_schema_for_gemini(tool["parameters"])
89
+
90
+ function_declarations.append(
91
+ FunctionDeclaration(
92
+ name=tool["name"],
93
+ description=tool["description"],
94
+ parameters=sanitized_parameters
95
+ )
96
+ )
97
+
98
+ logger.debug(f"Sanitized tool schema for Gemini: {tool['name']}")
99
+
100
+ return [Tool(function_declarations=function_declarations)]
101
+
102
+ def _convert_messages_to_gemini_format(self, messages: List[Dict[str, str]], system_prompt: str) -> List[Dict[str, str]]:
103
+ """
104
+ Convert standard message format to Gemini format.
105
+
106
+ Args:
107
+ messages: Standard message format [{"role": "user", "content": "..."}]
108
+ system_prompt: System instructions
109
+
110
+ Returns:
111
+ Gemini-formatted messages
112
+ """
113
+ gemini_messages = []
114
+
115
+ # Add system prompt as first user message if provided
116
+ if system_prompt:
117
+ gemini_messages.append({
118
+ "role": "user",
119
+ "parts": [{"text": system_prompt}]
120
+ })
121
+ gemini_messages.append({
122
+ "role": "model",
123
+ "parts": [{"text": "Understood. I'll follow these instructions."}]
124
+ })
125
+
126
+ # Convert messages
127
+ for msg in messages:
128
+ role = "user" if msg["role"] == "user" else "model"
129
+ gemini_messages.append({
130
+ "role": role,
131
+ "parts": [{"text": msg["content"]}]
132
+ })
133
+
134
+ return gemini_messages
135
+
136
+ async def generate_response_with_tools(
137
+ self,
138
+ messages: List[Dict[str, str]],
139
+ system_prompt: str,
140
+ tools: List[Dict[str, Any]]
141
+ ) -> LLMResponse:
142
+ """
143
+ Generate a response with function calling support.
144
+
145
+ Args:
146
+ messages: Conversation history
147
+ system_prompt: System instructions
148
+ tools: Tool definitions
149
+
150
+ Returns:
151
+ LLMResponse with content and/or tool_calls
152
+ """
153
+ try:
154
+ # Convert tools to Gemini format
155
+ gemini_tools = self._convert_tools_to_gemini_format(tools)
156
+
157
+ # Convert messages to Gemini format
158
+ gemini_messages = self._convert_messages_to_gemini_format(messages, system_prompt)
159
+
160
+ # Generate response with function calling
161
+ response = self.client.generate_content(
162
+ gemini_messages,
163
+ tools=gemini_tools,
164
+ generation_config={
165
+ "temperature": self.temperature,
166
+ "max_output_tokens": self.max_tokens
167
+ }
168
+ )
169
+
170
+ # Check if function calls were made
171
+ if response.candidates[0].content.parts:
172
+ first_part = response.candidates[0].content.parts[0]
173
+
174
+ # Check for function call
175
+ if hasattr(first_part, 'function_call') and first_part.function_call:
176
+ function_call = first_part.function_call
177
+ tool_calls = [{
178
+ "name": function_call.name,
179
+ "arguments": dict(function_call.args)
180
+ }]
181
+ logger.info(f"Gemini requested function call: {function_call.name}")
182
+ return LLMResponse(
183
+ content=None,
184
+ tool_calls=tool_calls,
185
+ finish_reason="function_call"
186
+ )
187
+
188
+ # Regular text response
189
+ content = response.text if hasattr(response, 'text') else None
190
+ logger.info("Gemini generated text response")
191
+ return LLMResponse(
192
+ content=content,
193
+ finish_reason="stop"
194
+ )
195
+
196
+ except Exception as e:
197
+ logger.error(f"Gemini API error: {str(e)}")
198
+ raise
199
+
200
+ async def generate_response_with_tool_results(
201
+ self,
202
+ messages: List[Dict[str, str]],
203
+ tool_calls: List[Dict[str, Any]],
204
+ tool_results: List[Dict[str, Any]]
205
+ ) -> LLMResponse:
206
+ """
207
+ Generate a final response after tool execution.
208
+
209
+ Args:
210
+ messages: Original conversation history
211
+ tool_calls: Tool calls that were made
212
+ tool_results: Results from tool execution
213
+
214
+ Returns:
215
+ LLMResponse with final content
216
+ """
217
+ try:
218
+ # Format tool results as a message
219
+ tool_results_text = "\n\n".join([
220
+ f"Tool: {call['name']}\nResult: {result}"
221
+ for call, result in zip(tool_calls, tool_results)
222
+ ])
223
+
224
+ # Add tool results to messages
225
+ messages_with_results = messages + [
226
+ {"role": "assistant", "content": f"I called the following tools:\n{tool_results_text}"},
227
+ {"role": "user", "content": "Based on these tool results, provide a natural language response to the user."}
228
+ ]
229
+
230
+ # Generate final response
231
+ gemini_messages = self._convert_messages_to_gemini_format(messages_with_results, "")
232
+ response = self.client.generate_content(
233
+ gemini_messages,
234
+ generation_config={
235
+ "temperature": self.temperature,
236
+ "max_output_tokens": self.max_tokens
237
+ }
238
+ )
239
+
240
+ content = response.text if hasattr(response, 'text') else None
241
+ logger.info("Gemini generated final response after tool execution")
242
+ return LLMResponse(
243
+ content=content,
244
+ finish_reason="stop"
245
+ )
246
+
247
+ except Exception as e:
248
+ logger.error(f"Gemini API error in tool results: {str(e)}")
249
+ raise
250
+
251
+ async def generate_simple_response(
252
+ self,
253
+ messages: List[Dict[str, str]],
254
+ system_prompt: str
255
+ ) -> LLMResponse:
256
+ """
257
+ Generate a simple response without function calling.
258
+
259
+ Args:
260
+ messages: Conversation history
261
+ system_prompt: System instructions
262
+
263
+ Returns:
264
+ LLMResponse with content
265
+ """
266
+ try:
267
+ gemini_messages = self._convert_messages_to_gemini_format(messages, system_prompt)
268
+ response = self.client.generate_content(
269
+ gemini_messages,
270
+ generation_config={
271
+ "temperature": self.temperature,
272
+ "max_output_tokens": self.max_tokens
273
+ }
274
+ )
275
+
276
+ content = response.text if hasattr(response, 'text') else None
277
+ logger.info("Gemini generated simple response")
278
+ return LLMResponse(
279
+ content=content,
280
+ finish_reason="stop"
281
+ )
282
+
283
+ except Exception as e:
284
+ logger.error(f"Gemini API error: {str(e)}")
285
+ raise
src/agent/providers/openrouter.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenRouter Provider Implementation
3
+
4
+ OpenRouter API provider with function calling support.
5
+ Fallback provider for when Gemini rate limits are exceeded.
6
+ Uses OpenAI-compatible API format.
7
+ """
8
+
9
+ import logging
10
+ from typing import List, Dict, Any
11
+ import httpx
12
+
13
+ from .base import LLMProvider, LLMResponse
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class OpenRouterProvider(LLMProvider):
19
+ """
20
+ OpenRouter API provider implementation.
21
+
22
+ Features:
23
+ - OpenAI-compatible API
24
+ - Access to multiple free models
25
+ - Function calling support
26
+ - Recommended free model: google/gemini-flash-1.5
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ api_key: str,
32
+ model: str = "google/gemini-flash-1.5",
33
+ temperature: float = 0.7,
34
+ max_tokens: int = 8192
35
+ ):
36
+ super().__init__(api_key, model, temperature, max_tokens)
37
+ self.base_url = "https://openrouter.ai/api/v1"
38
+ self.headers = {
39
+ "Authorization": f"Bearer {api_key}",
40
+ "Content-Type": "application/json"
41
+ }
42
+ logger.info(f"Initialized OpenRouterProvider with model: {model}")
43
+
44
+ def _convert_tools_to_openai_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
45
+ """
46
+ Convert MCP tool definitions to OpenAI function format.
47
+
48
+ Args:
49
+ tools: MCP tool definitions
50
+
51
+ Returns:
52
+ List of OpenAI-formatted function definitions
53
+ """
54
+ return [
55
+ {
56
+ "type": "function",
57
+ "function": {
58
+ "name": tool["name"],
59
+ "description": tool["description"],
60
+ "parameters": tool["parameters"]
61
+ }
62
+ }
63
+ for tool in tools
64
+ ]
65
+
66
+ async def generate_response_with_tools(
67
+ self,
68
+ messages: List[Dict[str, str]],
69
+ system_prompt: str,
70
+ tools: List[Dict[str, Any]]
71
+ ) -> LLMResponse:
72
+ """
73
+ Generate a response with function calling support.
74
+
75
+ Args:
76
+ messages: Conversation history
77
+ system_prompt: System instructions
78
+ tools: Tool definitions
79
+
80
+ Returns:
81
+ LLMResponse with content and/or tool_calls
82
+ """
83
+ try:
84
+ # Prepare messages with system prompt
85
+ formatted_messages = [{"role": "system", "content": system_prompt}] + messages
86
+
87
+ # Convert tools to OpenAI format
88
+ openai_tools = self._convert_tools_to_openai_format(tools)
89
+
90
+ # Make API request
91
+ async with httpx.AsyncClient(timeout=30.0) as client:
92
+ response = await client.post(
93
+ f"{self.base_url}/chat/completions",
94
+ headers=self.headers,
95
+ json={
96
+ "model": self.model,
97
+ "messages": formatted_messages,
98
+ "tools": openai_tools,
99
+ "temperature": self.temperature,
100
+ "max_tokens": self.max_tokens
101
+ }
102
+ )
103
+ response.raise_for_status()
104
+ data = response.json()
105
+
106
+ # Parse response
107
+ choice = data["choices"][0]
108
+ message = choice["message"]
109
+
110
+ # Check for function calls
111
+ if "tool_calls" in message and message["tool_calls"]:
112
+ tool_calls = [
113
+ {
114
+ "name": tc["function"]["name"],
115
+ "arguments": tc["function"]["arguments"]
116
+ }
117
+ for tc in message["tool_calls"]
118
+ ]
119
+ logger.info(f"OpenRouter requested function calls: {[tc['name'] for tc in tool_calls]}")
120
+ return LLMResponse(
121
+ content=None,
122
+ tool_calls=tool_calls,
123
+ finish_reason=choice.get("finish_reason", "function_call")
124
+ )
125
+
126
+ # Regular text response
127
+ content = message.get("content")
128
+ logger.info("OpenRouter generated text response")
129
+ return LLMResponse(
130
+ content=content,
131
+ finish_reason=choice.get("finish_reason", "stop")
132
+ )
133
+
134
+ except httpx.HTTPStatusError as e:
135
+ logger.error(f"OpenRouter API HTTP error: {e.response.status_code} - {e.response.text}")
136
+ raise
137
+ except Exception as e:
138
+ logger.error(f"OpenRouter API error: {str(e)}")
139
+ raise
140
+
141
+ async def generate_response_with_tool_results(
142
+ self,
143
+ messages: List[Dict[str, str]],
144
+ tool_calls: List[Dict[str, Any]],
145
+ tool_results: List[Dict[str, Any]]
146
+ ) -> LLMResponse:
147
+ """
148
+ Generate a final response after tool execution.
149
+
150
+ Args:
151
+ messages: Original conversation history
152
+ tool_calls: Tool calls that were made
153
+ tool_results: Results from tool execution
154
+
155
+ Returns:
156
+ LLMResponse with final content
157
+ """
158
+ try:
159
+ # Format tool results as messages
160
+ messages_with_results = messages.copy()
161
+
162
+ # Add assistant message with tool calls
163
+ messages_with_results.append({
164
+ "role": "assistant",
165
+ "content": None,
166
+ "tool_calls": [
167
+ {
168
+ "id": f"call_{i}",
169
+ "type": "function",
170
+ "function": {
171
+ "name": call["name"],
172
+ "arguments": str(call["arguments"])
173
+ }
174
+ }
175
+ for i, call in enumerate(tool_calls)
176
+ ]
177
+ })
178
+
179
+ # Add tool result messages
180
+ for i, (call, result) in enumerate(zip(tool_calls, tool_results)):
181
+ messages_with_results.append({
182
+ "role": "tool",
183
+ "tool_call_id": f"call_{i}",
184
+ "content": str(result)
185
+ })
186
+
187
+ # Generate final response
188
+ async with httpx.AsyncClient(timeout=30.0) as client:
189
+ response = await client.post(
190
+ f"{self.base_url}/chat/completions",
191
+ headers=self.headers,
192
+ json={
193
+ "model": self.model,
194
+ "messages": messages_with_results,
195
+ "temperature": self.temperature,
196
+ "max_tokens": self.max_tokens
197
+ }
198
+ )
199
+ response.raise_for_status()
200
+ data = response.json()
201
+
202
+ choice = data["choices"][0]
203
+ content = choice["message"].get("content")
204
+ logger.info("OpenRouter generated final response after tool execution")
205
+ return LLMResponse(
206
+ content=content,
207
+ finish_reason=choice.get("finish_reason", "stop")
208
+ )
209
+
210
+ except httpx.HTTPStatusError as e:
211
+ logger.error(f"OpenRouter API HTTP error: {e.response.status_code} - {e.response.text}")
212
+ raise
213
+ except Exception as e:
214
+ logger.error(f"OpenRouter API error in tool results: {str(e)}")
215
+ raise
216
+
217
+ async def generate_simple_response(
218
+ self,
219
+ messages: List[Dict[str, str]],
220
+ system_prompt: str
221
+ ) -> LLMResponse:
222
+ """
223
+ Generate a simple response without function calling.
224
+
225
+ Args:
226
+ messages: Conversation history
227
+ system_prompt: System instructions
228
+
229
+ Returns:
230
+ LLMResponse with content
231
+ """
232
+ try:
233
+ # Prepare messages with system prompt
234
+ formatted_messages = [{"role": "system", "content": system_prompt}] + messages
235
+
236
+ # Make API request
237
+ async with httpx.AsyncClient(timeout=30.0) as client:
238
+ response = await client.post(
239
+ f"{self.base_url}/chat/completions",
240
+ headers=self.headers,
241
+ json={
242
+ "model": self.model,
243
+ "messages": formatted_messages,
244
+ "temperature": self.temperature,
245
+ "max_tokens": self.max_tokens
246
+ }
247
+ )
248
+ response.raise_for_status()
249
+ data = response.json()
250
+
251
+ choice = data["choices"][0]
252
+ content = choice["message"].get("content")
253
+ logger.info("OpenRouter generated simple response")
254
+ return LLMResponse(
255
+ content=content,
256
+ finish_reason=choice.get("finish_reason", "stop")
257
+ )
258
+
259
+ except httpx.HTTPStatusError as e:
260
+ logger.error(f"OpenRouter API HTTP error: {e.response.status_code} - {e.response.text}")
261
+ raise
262
+ except Exception as e:
263
+ logger.error(f"OpenRouter API error: {str(e)}")
264
+ raise
src/api/routes/__pycache__/chat.cpython-313.pyc ADDED
Binary file (10.6 kB). View file
 
src/api/routes/__pycache__/conversations.cpython-313.pyc ADDED
Binary file (10.1 kB). View file
 
src/api/routes/chat.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chat API endpoint for AI chatbot."""
2
+ from fastapi import APIRouter, Depends, HTTPException, status
3
+ from sqlmodel import Session
4
+ from typing import Dict, Any
5
+ import logging
6
+ from datetime import datetime
7
+
8
+ from src.core.database import get_session
9
+ from src.core.security import get_current_user
10
+ from src.core.config import settings
11
+ from src.schemas.chat_request import ChatRequest
12
+ from src.schemas.chat_response import ChatResponse
13
+ from src.services.conversation_service import ConversationService
14
+ from src.agent.agent_config import AgentConfiguration
15
+ from src.agent.agent_runner import AgentRunner
16
+ from src.mcp import tool_registry
17
+ from src.core.exceptions import (
18
+ classify_ai_error,
19
+ APIKeyMissingException,
20
+ APIKeyInvalidException
21
+ )
22
+
23
+
24
+ # Configure logging
25
+ logger = logging.getLogger(__name__)
26
+
27
+ router = APIRouter(prefix="/api", tags=["chat"])
28
+
29
+
30
+ def generate_conversation_title(first_user_message: str) -> str:
31
+ """Generate a conversation title from the first user message.
32
+
33
+ Args:
34
+ first_user_message: The first message from the user
35
+
36
+ Returns:
37
+ A title string (max 50 characters)
38
+ """
39
+ # Remove leading/trailing whitespace
40
+ message = first_user_message.strip()
41
+
42
+ # Try to extract the first sentence or first 50 characters
43
+ # Split by common sentence endings
44
+ for delimiter in ['. ', '! ', '? ', '\n']:
45
+ if delimiter in message:
46
+ title = message.split(delimiter)[0]
47
+ break
48
+ else:
49
+ # No sentence delimiter found, use first 50 chars
50
+ title = message[:50]
51
+
52
+ # If title is too short (less than 10 chars), use timestamp-based default
53
+ if len(title) < 10:
54
+ return f"Chat {datetime.now().strftime('%b %d, %I:%M %p')}"
55
+
56
+ # Truncate to 50 characters and add ellipsis if needed
57
+ if len(title) > 50:
58
+ title = title[:47] + "..."
59
+
60
+ return title
61
+
62
+
63
+ @router.post("/{user_id}/chat", response_model=ChatResponse)
64
+ async def chat(
65
+ user_id: int,
66
+ request: ChatRequest,
67
+ db: Session = Depends(get_session),
68
+ current_user: Dict[str, Any] = Depends(get_current_user)
69
+ ) -> ChatResponse:
70
+ """Handle chat messages from users.
71
+
72
+ Args:
73
+ user_id: ID of the user sending the message
74
+ request: ChatRequest containing the user's message
75
+ db: Database session
76
+ current_user: Authenticated user from JWT token
77
+
78
+ Returns:
79
+ ChatResponse containing the AI's response
80
+
81
+ Raises:
82
+ HTTPException 401: If user is not authenticated or user_id doesn't match
83
+ HTTPException 404: If conversation_id is provided but not found
84
+ HTTPException 500: If AI provider fails to generate response
85
+ """
86
+ # Verify user authorization
87
+ if current_user["id"] != user_id:
88
+ raise HTTPException(
89
+ status_code=status.HTTP_401_UNAUTHORIZED,
90
+ detail="Not authorized to access this user's chat"
91
+ )
92
+
93
+ try:
94
+ # Validate request message length
95
+ if not request.message or len(request.message.strip()) == 0:
96
+ raise HTTPException(
97
+ status_code=status.HTTP_400_BAD_REQUEST,
98
+ detail="Message cannot be empty"
99
+ )
100
+
101
+ if len(request.message) > 10000:
102
+ raise HTTPException(
103
+ status_code=status.HTTP_400_BAD_REQUEST,
104
+ detail="Message exceeds maximum length of 10,000 characters"
105
+ )
106
+
107
+ # Initialize services
108
+ conversation_service = ConversationService(db)
109
+
110
+ # Initialize agent configuration from settings
111
+ try:
112
+ agent_config = AgentConfiguration(
113
+ provider=settings.LLM_PROVIDER,
114
+ fallback_provider=settings.FALLBACK_PROVIDER,
115
+ gemini_api_key=settings.GEMINI_API_KEY,
116
+ openrouter_api_key=settings.OPENROUTER_API_KEY,
117
+ cohere_api_key=settings.COHERE_API_KEY,
118
+ temperature=settings.AGENT_TEMPERATURE,
119
+ max_tokens=settings.AGENT_MAX_TOKENS,
120
+ max_messages=settings.CONVERSATION_MAX_MESSAGES,
121
+ max_conversation_tokens=settings.CONVERSATION_MAX_TOKENS
122
+ )
123
+ agent_config.validate()
124
+
125
+ # Create agent runner with tool registry
126
+ agent_runner = AgentRunner(agent_config, tool_registry)
127
+ except ValueError as e:
128
+ logger.error(f"Agent initialization failed: {str(e)}")
129
+ # Check if it's an API key issue
130
+ error_msg = str(e).lower()
131
+ if "api key" in error_msg:
132
+ if "not found" in error_msg or "missing" in error_msg:
133
+ raise APIKeyMissingException(provider=settings.LLM_PROVIDER)
134
+ elif "invalid" in error_msg:
135
+ raise APIKeyInvalidException(provider=settings.LLM_PROVIDER)
136
+ raise HTTPException(
137
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
138
+ detail="AI service is not properly configured. Please contact support."
139
+ )
140
+
141
+ # Get or create conversation
142
+ is_new_conversation = False
143
+ if request.conversation_id:
144
+ conversation = conversation_service.get_conversation(
145
+ request.conversation_id,
146
+ user_id
147
+ )
148
+ if not conversation:
149
+ raise HTTPException(
150
+ status_code=status.HTTP_404_NOT_FOUND,
151
+ detail=f"Conversation {request.conversation_id} not found or you don't have access to it"
152
+ )
153
+ else:
154
+ # Create new conversation with auto-generated title
155
+ try:
156
+ # Generate title from first user message
157
+ title = generate_conversation_title(request.message)
158
+ conversation = conversation_service.create_conversation(
159
+ user_id=user_id,
160
+ title=title
161
+ )
162
+ is_new_conversation = True
163
+ logger.info(f"Created new conversation {conversation.id} with title: {title}")
164
+ except Exception as e:
165
+ logger.error(f"Failed to create conversation: {str(e)}")
166
+ raise HTTPException(
167
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
168
+ detail="Failed to create conversation. Please try again."
169
+ )
170
+
171
+ # Add user message to conversation
172
+ try:
173
+ user_message = conversation_service.add_message(
174
+ conversation_id=conversation.id,
175
+ role="user",
176
+ content=request.message,
177
+ token_count=len(request.message) // 4 # Rough token estimate
178
+ )
179
+ except Exception as e:
180
+ logger.error(f"Failed to save user message: {str(e)}")
181
+ raise HTTPException(
182
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
183
+ detail="Failed to save your message. Please try again."
184
+ )
185
+
186
+ # Get conversation history and format for agent
187
+ history_messages = conversation_service.get_conversation_messages(
188
+ conversation_id=conversation.id
189
+ )
190
+
191
+ # Format messages for agent with trimming
192
+ formatted_messages = conversation_service.format_messages_for_agent(
193
+ messages=history_messages,
194
+ max_messages=agent_config.max_messages,
195
+ max_tokens=agent_config.max_conversation_tokens
196
+ )
197
+
198
+ # Generate AI response with tool calling support
199
+ system_prompt = request.system_prompt or agent_config.system_prompt
200
+
201
+ try:
202
+ agent_result = await agent_runner.execute(
203
+ messages=formatted_messages,
204
+ user_id=user_id, # Inject user context for security
205
+ system_prompt=system_prompt
206
+ )
207
+ except Exception as e:
208
+ # Use classify_ai_error to determine the appropriate exception
209
+ logger.error(f"AI service error for user {user_id}: {str(e)}")
210
+ provider = agent_result.get("provider") if 'agent_result' in locals() else settings.LLM_PROVIDER
211
+ raise classify_ai_error(e, provider=provider)
212
+
213
+ # Add AI response to conversation with tool call metadata
214
+ try:
215
+ # Prepare metadata if tools were used
216
+ tool_metadata = None
217
+ if agent_result.get("tool_calls"):
218
+ # Convert ToolExecutionResult objects to dicts for JSON serialization
219
+ tool_results = agent_result.get("tool_results", [])
220
+ serializable_results = []
221
+ for result in tool_results:
222
+ if hasattr(result, '__dict__'):
223
+ # Convert dataclass/object to dict
224
+ serializable_results.append({
225
+ "success": result.success,
226
+ "data": result.data,
227
+ "message": result.message,
228
+ "error": result.error
229
+ })
230
+ else:
231
+ # Already a dict
232
+ serializable_results.append(result)
233
+
234
+ tool_metadata = {
235
+ "tool_calls": agent_result["tool_calls"],
236
+ "tool_results": serializable_results,
237
+ "provider": agent_result.get("provider")
238
+ }
239
+
240
+ assistant_message = conversation_service.add_message(
241
+ conversation_id=conversation.id,
242
+ role="assistant",
243
+ content=agent_result["content"],
244
+ token_count=len(agent_result["content"]) // 4 # Rough token estimate
245
+ )
246
+
247
+ # Update tool_metadata if tools were used
248
+ if tool_metadata:
249
+ assistant_message.tool_metadata = tool_metadata
250
+ db.add(assistant_message)
251
+ db.commit()
252
+ except Exception as e:
253
+ logger.error(f"Failed to save AI response: {str(e)}")
254
+ # Still return the response even if saving fails
255
+ # User gets the response but it won't be in history
256
+ logger.warning(f"Returning response without saving to database for conversation {conversation.id}")
257
+
258
+ # Log tool usage if any
259
+ if agent_result.get("tool_calls"):
260
+ logger.info(f"Agent used {len(agent_result['tool_calls'])} tools for user {user_id}")
261
+
262
+ # Return response
263
+ return ChatResponse(
264
+ conversation_id=conversation.id,
265
+ message=agent_result["content"],
266
+ role="assistant",
267
+ timestamp=assistant_message.timestamp if 'assistant_message' in locals() else user_message.timestamp,
268
+ token_count=len(agent_result["content"]) // 4,
269
+ model=agent_result.get("provider")
270
+ )
271
+
272
+ except HTTPException:
273
+ # Re-raise HTTP exceptions
274
+ raise
275
+ except Exception as e:
276
+ # Catch-all for unexpected errors
277
+ logger.exception(f"Unexpected error in chat endpoint: {str(e)}")
278
+ raise HTTPException(
279
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
280
+ detail="An unexpected error occurred. Please try again later."
281
+ )
src/api/routes/conversations.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conversations API endpoints for managing chat conversations."""
2
+ from fastapi import APIRouter, Depends, HTTPException, status, Query
3
+ from sqlmodel import Session
4
+ from typing import Dict, Any, List
5
+ import logging
6
+
7
+ from src.core.database import get_session
8
+ from src.core.security import get_current_user
9
+ from src.services.conversation_service import ConversationService
10
+ from src.schemas.conversation import (
11
+ ConversationListResponse,
12
+ ConversationSummary,
13
+ MessageListResponse,
14
+ MessageResponse,
15
+ UpdateConversationRequest,
16
+ UpdateConversationResponse
17
+ )
18
+
19
+ # Configure logging
20
+ logger = logging.getLogger(__name__)
21
+
22
+ router = APIRouter(prefix="/api", tags=["conversations"])
23
+
24
+
25
+ @router.get("/{user_id}/conversations", response_model=ConversationListResponse)
26
+ async def list_conversations(
27
+ user_id: int,
28
+ limit: int = Query(50, ge=1, le=100, description="Maximum number of conversations to return"),
29
+ db: Session = Depends(get_session),
30
+ current_user: Dict[str, Any] = Depends(get_current_user)
31
+ ) -> ConversationListResponse:
32
+ """List all conversations for a user.
33
+
34
+ Args:
35
+ user_id: ID of the user
36
+ limit: Maximum number of conversations to return (default: 50, max: 100)
37
+ db: Database session
38
+ current_user: Authenticated user from JWT token
39
+
40
+ Returns:
41
+ ConversationListResponse with list of conversations
42
+
43
+ Raises:
44
+ HTTPException 401: If user is not authenticated or user_id doesn't match
45
+ """
46
+ # Verify user authorization
47
+ if current_user["id"] != user_id:
48
+ raise HTTPException(
49
+ status_code=status.HTTP_401_UNAUTHORIZED,
50
+ detail="Not authorized to access this user's conversations"
51
+ )
52
+
53
+ try:
54
+ conversation_service = ConversationService(db)
55
+ conversations = conversation_service.get_user_conversations(user_id, limit=limit)
56
+
57
+ # Build conversation summaries with message count and preview
58
+ summaries: List[ConversationSummary] = []
59
+ for conv in conversations:
60
+ # Get messages for this conversation
61
+ messages = conversation_service.get_conversation_messages(conv.id)
62
+ message_count = len(messages)
63
+
64
+ # Get last message preview
65
+ last_message_preview = None
66
+ if messages:
67
+ last_msg = messages[-1]
68
+ # Take first 100 characters of the last message
69
+ last_message_preview = last_msg.content[:100]
70
+ if len(last_msg.content) > 100:
71
+ last_message_preview += "..."
72
+
73
+ summaries.append(ConversationSummary(
74
+ id=conv.id,
75
+ title=conv.title,
76
+ created_at=conv.created_at,
77
+ updated_at=conv.updated_at,
78
+ message_count=message_count,
79
+ last_message_preview=last_message_preview
80
+ ))
81
+
82
+ return ConversationListResponse(
83
+ conversations=summaries,
84
+ total=len(summaries)
85
+ )
86
+
87
+ except Exception as e:
88
+ logger.exception(f"Failed to list conversations for user {user_id}: {str(e)}")
89
+ raise HTTPException(
90
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
91
+ detail="Failed to retrieve conversations. Please try again."
92
+ )
93
+
94
+
95
+ @router.get("/{user_id}/conversations/{conversation_id}/messages", response_model=MessageListResponse)
96
+ async def get_conversation_messages(
97
+ user_id: int,
98
+ conversation_id: int,
99
+ offset: int = Query(0, ge=0, description="Number of messages to skip"),
100
+ limit: int = Query(50, ge=1, le=200, description="Maximum number of messages to return"),
101
+ db: Session = Depends(get_session),
102
+ current_user: Dict[str, Any] = Depends(get_current_user)
103
+ ) -> MessageListResponse:
104
+ """Get message history for a conversation.
105
+
106
+ Args:
107
+ user_id: ID of the user
108
+ conversation_id: ID of the conversation
109
+ offset: Number of messages to skip (for pagination)
110
+ limit: Maximum number of messages to return (default: 50, max: 200)
111
+ db: Database session
112
+ current_user: Authenticated user from JWT token
113
+
114
+ Returns:
115
+ MessageListResponse with list of messages
116
+
117
+ Raises:
118
+ HTTPException 401: If user is not authenticated or user_id doesn't match
119
+ HTTPException 404: If conversation not found or user doesn't have access
120
+ """
121
+ # Verify user authorization
122
+ if current_user["id"] != user_id:
123
+ raise HTTPException(
124
+ status_code=status.HTTP_401_UNAUTHORIZED,
125
+ detail="Not authorized to access this user's conversations"
126
+ )
127
+
128
+ try:
129
+ conversation_service = ConversationService(db)
130
+
131
+ # Verify conversation exists and belongs to user
132
+ conversation = conversation_service.get_conversation(conversation_id, user_id)
133
+ if not conversation:
134
+ raise HTTPException(
135
+ status_code=status.HTTP_404_NOT_FOUND,
136
+ detail=f"Conversation {conversation_id} not found or you don't have access to it"
137
+ )
138
+
139
+ # Get all messages (we'll handle pagination manually)
140
+ all_messages = conversation_service.get_conversation_messages(conversation_id)
141
+ total = len(all_messages)
142
+
143
+ # Apply pagination
144
+ paginated_messages = all_messages[offset:offset + limit]
145
+
146
+ # Convert to response format
147
+ message_responses = [
148
+ MessageResponse(
149
+ id=msg.id,
150
+ role=msg.role,
151
+ content=msg.content,
152
+ timestamp=msg.timestamp,
153
+ token_count=msg.token_count
154
+ )
155
+ for msg in paginated_messages
156
+ ]
157
+
158
+ return MessageListResponse(
159
+ conversation_id=conversation_id,
160
+ messages=message_responses,
161
+ total=total
162
+ )
163
+
164
+ except HTTPException:
165
+ raise
166
+ except Exception as e:
167
+ logger.exception(f"Failed to get messages for conversation {conversation_id}: {str(e)}")
168
+ raise HTTPException(
169
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
170
+ detail="Failed to retrieve messages. Please try again."
171
+ )
172
+
173
+
174
+ @router.patch("/{user_id}/conversations/{conversation_id}", response_model=UpdateConversationResponse)
175
+ async def update_conversation(
176
+ user_id: int,
177
+ conversation_id: int,
178
+ request: UpdateConversationRequest,
179
+ db: Session = Depends(get_session),
180
+ current_user: Dict[str, Any] = Depends(get_current_user)
181
+ ) -> UpdateConversationResponse:
182
+ """Update a conversation's title.
183
+
184
+ Args:
185
+ user_id: ID of the user
186
+ conversation_id: ID of the conversation
187
+ request: UpdateConversationRequest with new title
188
+ db: Database session
189
+ current_user: Authenticated user from JWT token
190
+
191
+ Returns:
192
+ UpdateConversationResponse with updated conversation
193
+
194
+ Raises:
195
+ HTTPException 401: If user is not authenticated or user_id doesn't match
196
+ HTTPException 404: If conversation not found or user doesn't have access
197
+ """
198
+ # Verify user authorization
199
+ if current_user["id"] != user_id:
200
+ raise HTTPException(
201
+ status_code=status.HTTP_401_UNAUTHORIZED,
202
+ detail="Not authorized to access this user's conversations"
203
+ )
204
+
205
+ try:
206
+ conversation_service = ConversationService(db)
207
+
208
+ # Verify conversation exists and belongs to user
209
+ conversation = conversation_service.get_conversation(conversation_id, user_id)
210
+ if not conversation:
211
+ raise HTTPException(
212
+ status_code=status.HTTP_404_NOT_FOUND,
213
+ detail=f"Conversation {conversation_id} not found or you don't have access to it"
214
+ )
215
+
216
+ # Update the title
217
+ from datetime import datetime
218
+ conversation.title = request.title
219
+ conversation.updated_at = datetime.utcnow()
220
+ db.add(conversation)
221
+ db.commit()
222
+ db.refresh(conversation)
223
+
224
+ return UpdateConversationResponse(
225
+ id=conversation.id,
226
+ title=conversation.title,
227
+ updated_at=conversation.updated_at
228
+ )
229
+
230
+ except HTTPException:
231
+ raise
232
+ except Exception as e:
233
+ logger.exception(f"Failed to update conversation {conversation_id}: {str(e)}")
234
+ raise HTTPException(
235
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
236
+ detail="Failed to update conversation. Please try again."
237
+ )
238
+
239
+
240
+ @router.delete("/{user_id}/conversations/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
241
+ async def delete_conversation(
242
+ user_id: int,
243
+ conversation_id: int,
244
+ db: Session = Depends(get_session),
245
+ current_user: Dict[str, Any] = Depends(get_current_user)
246
+ ) -> None:
247
+ """Delete a conversation and all its messages.
248
+
249
+ Args:
250
+ user_id: ID of the user
251
+ conversation_id: ID of the conversation
252
+ db: Database session
253
+ current_user: Authenticated user from JWT token
254
+
255
+ Returns:
256
+ None (204 No Content)
257
+
258
+ Raises:
259
+ HTTPException 401: If user is not authenticated or user_id doesn't match
260
+ HTTPException 404: If conversation not found or user doesn't have access
261
+ """
262
+ # Verify user authorization
263
+ if current_user["id"] != user_id:
264
+ raise HTTPException(
265
+ status_code=status.HTTP_401_UNAUTHORIZED,
266
+ detail="Not authorized to access this user's conversations"
267
+ )
268
+
269
+ try:
270
+ conversation_service = ConversationService(db)
271
+
272
+ # Delete the conversation (service method handles authorization check)
273
+ deleted = conversation_service.delete_conversation(conversation_id, user_id)
274
+
275
+ if not deleted:
276
+ raise HTTPException(
277
+ status_code=status.HTTP_404_NOT_FOUND,
278
+ detail=f"Conversation {conversation_id} not found or you don't have access to it"
279
+ )
280
+
281
+ # Return 204 No Content (no response body)
282
+ return None
283
+
284
+ except HTTPException:
285
+ raise
286
+ except Exception as e:
287
+ logger.exception(f"Failed to delete conversation {conversation_id}: {str(e)}")
288
+ raise HTTPException(
289
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
290
+ detail="Failed to delete conversation. Please try again."
291
+ )
src/core/__pycache__/config.cpython-313.pyc CHANGED
Binary files a/src/core/__pycache__/config.cpython-313.pyc and b/src/core/__pycache__/config.cpython-313.pyc differ
 
src/core/__pycache__/exceptions.cpython-313.pyc ADDED
Binary file (6.86 kB). View file
 
src/core/__pycache__/security.cpython-313.pyc CHANGED
Binary files a/src/core/__pycache__/security.cpython-313.pyc and b/src/core/__pycache__/security.cpython-313.pyc differ
 
src/core/config.py CHANGED
@@ -19,6 +19,21 @@ class Settings(BaseSettings):
19
  JWT_ALGORITHM: str = "HS256"
20
  JWT_EXPIRATION_DAYS: int = 7
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  class Config:
23
  env_file = ".env"
24
  case_sensitive = True
 
19
  JWT_ALGORITHM: str = "HS256"
20
  JWT_EXPIRATION_DAYS: int = 7
21
 
22
+ # LLM Provider Configuration
23
+ LLM_PROVIDER: str = "gemini" # Primary provider: gemini, openrouter, cohere
24
+ FALLBACK_PROVIDER: str | None = None # Optional fallback provider
25
+ GEMINI_API_KEY: str | None = None # Required if using Gemini
26
+ OPENROUTER_API_KEY: str | None = None # Required if using OpenRouter
27
+ COHERE_API_KEY: str | None = None # Required if using Cohere
28
+
29
+ # Agent Configuration
30
+ AGENT_TEMPERATURE: float = 0.7 # Sampling temperature (0.0 to 1.0)
31
+ AGENT_MAX_TOKENS: int = 8192 # Maximum tokens in response
32
+
33
+ # Conversation Settings (for free-tier constraints)
34
+ CONVERSATION_MAX_MESSAGES: int = 20 # Maximum messages to keep in history
35
+ CONVERSATION_MAX_TOKENS: int = 8000 # Maximum tokens in conversation history
36
+
37
  class Config:
38
  env_file = ".env"
39
  case_sensitive = True
src/core/exceptions.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom exception classes for structured error handling.
3
+ """
4
+
5
+ from typing import Optional
6
+ from fastapi import HTTPException, status
7
+ from src.schemas.error import ErrorCode
8
+
9
+
10
+ class AIProviderException(HTTPException):
11
+ """Base exception for AI provider errors."""
12
+
13
+ def __init__(
14
+ self,
15
+ error_code: str,
16
+ detail: str,
17
+ status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR,
18
+ provider: Optional[str] = None
19
+ ):
20
+ super().__init__(status_code=status_code, detail=detail)
21
+ self.error_code = error_code
22
+ self.source = "AI_PROVIDER"
23
+ self.provider = provider
24
+
25
+
26
+ class RateLimitExceededException(AIProviderException):
27
+ """Exception raised when AI provider rate limit is exceeded."""
28
+
29
+ def __init__(self, provider: Optional[str] = None):
30
+ super().__init__(
31
+ error_code=ErrorCode.RATE_LIMIT_EXCEEDED,
32
+ detail="AI service rate limit exceeded. Please wait a moment and try again.",
33
+ status_code=status.HTTP_429_TOO_MANY_REQUESTS,
34
+ provider=provider
35
+ )
36
+
37
+
38
+ class APIKeyMissingException(AIProviderException):
39
+ """Exception raised when API key is not configured."""
40
+
41
+ def __init__(self, provider: Optional[str] = None):
42
+ super().__init__(
43
+ error_code=ErrorCode.API_KEY_MISSING,
44
+ detail="AI service is not configured. Please add an API key.",
45
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
46
+ provider=provider
47
+ )
48
+
49
+
50
+ class APIKeyInvalidException(AIProviderException):
51
+ """Exception raised when API key is invalid or expired."""
52
+
53
+ def __init__(self, provider: Optional[str] = None):
54
+ super().__init__(
55
+ error_code=ErrorCode.API_KEY_INVALID,
56
+ detail="Your API key is invalid or expired. Please check your configuration.",
57
+ status_code=status.HTTP_401_UNAUTHORIZED,
58
+ provider=provider
59
+ )
60
+
61
+
62
+ class ProviderUnavailableException(AIProviderException):
63
+ """Exception raised when AI provider is temporarily unavailable."""
64
+
65
+ def __init__(self, provider: Optional[str] = None):
66
+ super().__init__(
67
+ error_code=ErrorCode.PROVIDER_UNAVAILABLE,
68
+ detail="AI service is temporarily unavailable. Please try again later.",
69
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
70
+ provider=provider
71
+ )
72
+
73
+
74
+ class ProviderErrorException(AIProviderException):
75
+ """Exception raised for generic AI provider errors."""
76
+
77
+ def __init__(self, detail: str, provider: Optional[str] = None):
78
+ super().__init__(
79
+ error_code=ErrorCode.PROVIDER_ERROR,
80
+ detail=detail,
81
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
82
+ provider=provider
83
+ )
84
+
85
+
86
+ def classify_ai_error(error: Exception, provider: Optional[str] = None) -> AIProviderException:
87
+ """
88
+ Classify an AI provider error and return appropriate exception.
89
+
90
+ Args:
91
+ error: The original exception from the AI provider
92
+ provider: Name of the AI provider (gemini, openrouter, cohere)
93
+
94
+ Returns:
95
+ Appropriate AIProviderException subclass
96
+ """
97
+ error_message = str(error).lower()
98
+
99
+ # Rate limit errors
100
+ if any(keyword in error_message for keyword in ["rate limit", "429", "quota exceeded", "too many requests"]):
101
+ return RateLimitExceededException(provider=provider)
102
+
103
+ # API key missing errors
104
+ if any(keyword in error_message for keyword in ["api key not found", "api key is required", "missing api key"]):
105
+ return APIKeyMissingException(provider=provider)
106
+
107
+ # API key invalid errors
108
+ if any(keyword in error_message for keyword in [
109
+ "invalid api key", "api key invalid", "unauthorized", "401",
110
+ "authentication failed", "invalid credentials", "api key expired"
111
+ ]):
112
+ return APIKeyInvalidException(provider=provider)
113
+
114
+ # Provider unavailable errors
115
+ if any(keyword in error_message for keyword in [
116
+ "503", "service unavailable", "temporarily unavailable",
117
+ "connection refused", "connection timeout", "timeout"
118
+ ]):
119
+ return ProviderUnavailableException(provider=provider)
120
+
121
+ # Generic provider error
122
+ return ProviderErrorException(
123
+ detail=f"AI service error: {str(error)}",
124
+ provider=provider
125
+ )
src/core/security.py CHANGED
@@ -111,9 +111,13 @@
111
  import jwt
112
  from datetime import datetime, timedelta
113
  from passlib.context import CryptContext
114
- from fastapi import HTTPException, status
 
 
 
115
 
116
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 
117
  import hashlib
118
  MAX_BCRYPT_BYTES = 72
119
 
@@ -165,3 +169,50 @@ def verify_jwt_token(token: str, secret: str) -> dict:
165
  raise HTTPException(status_code=401, detail="Token expired")
166
  except jwt.InvalidTokenError:
167
  raise HTTPException(status_code=401, detail="Invalid token")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  import jwt
112
  from datetime import datetime, timedelta
113
  from passlib.context import CryptContext
114
+ from fastapi import HTTPException, status, Depends
115
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
116
+ from typing import Dict, Any
117
+ from src.core.config import settings
118
 
119
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
120
+ security = HTTPBearer()
121
  import hashlib
122
  MAX_BCRYPT_BYTES = 72
123
 
 
169
  raise HTTPException(status_code=401, detail="Token expired")
170
  except jwt.InvalidTokenError:
171
  raise HTTPException(status_code=401, detail="Invalid token")
172
+
173
+
174
+ def get_current_user(
175
+ credentials: HTTPAuthorizationCredentials = Depends(security)
176
+ ) -> Dict[str, Any]:
177
+ """
178
+ FastAPI dependency to extract and validate JWT token from Authorization header.
179
+
180
+ Args:
181
+ credentials: HTTP Bearer token credentials from request header
182
+
183
+ Returns:
184
+ Dictionary containing user information from token payload:
185
+ - id: User ID (parsed from 'sub' claim)
186
+ - email: User email
187
+ - iat: Token issued at timestamp
188
+ - exp: Token expiration timestamp
189
+
190
+ Raises:
191
+ HTTPException 401: If token is missing, invalid, or expired
192
+ """
193
+ token = credentials.credentials
194
+
195
+ try:
196
+ payload = verify_jwt_token(token, settings.BETTER_AUTH_SECRET)
197
+
198
+ # Extract user ID from 'sub' claim and convert to integer
199
+ user_id = int(payload.get("sub"))
200
+
201
+ return {
202
+ "id": user_id,
203
+ "email": payload.get("email"),
204
+ "iat": payload.get("iat"),
205
+ "exp": payload.get("exp")
206
+ }
207
+ except ValueError:
208
+ raise HTTPException(
209
+ status_code=status.HTTP_401_UNAUTHORIZED,
210
+ detail="Invalid user ID in token",
211
+ headers={"WWW-Authenticate": "Bearer"}
212
+ )
213
+ except Exception as e:
214
+ raise HTTPException(
215
+ status_code=status.HTTP_401_UNAUTHORIZED,
216
+ detail=f"Authentication failed: {str(e)}",
217
+ headers={"WWW-Authenticate": "Bearer"}
218
+ )
src/main.py CHANGED
@@ -1,13 +1,76 @@
1
- from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
 
 
3
  from .core.config import settings
4
- from .api.routes import tasks, auth
 
 
 
 
 
 
5
 
6
  app = FastAPI(
7
  title=settings.APP_NAME,
8
  debug=settings.DEBUG
9
  )
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Configure CORS
12
  app.add_middleware(
13
  CORSMiddleware,
@@ -20,6 +83,8 @@ app.add_middleware(
20
  # Register routes
21
  app.include_router(auth.router)
22
  app.include_router(tasks.router)
 
 
23
 
24
 
25
  @app.get("/")
 
1
+ from fastapi import FastAPI, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse
4
+ import logging
5
  from .core.config import settings
6
+ from .api.routes import tasks, auth, chat, conversations
7
+ from .mcp import register_all_tools
8
+ from .core.exceptions import AIProviderException
9
+ from .schemas.error import ErrorResponse
10
+
11
+ # Configure logging
12
+ logger = logging.getLogger(__name__)
13
 
14
  app = FastAPI(
15
  title=settings.APP_NAME,
16
  debug=settings.DEBUG
17
  )
18
 
19
+
20
+ # Global exception handler for AIProviderException
21
+ @app.exception_handler(AIProviderException)
22
+ async def ai_provider_exception_handler(request: Request, exc: AIProviderException):
23
+ """Convert AIProviderException to structured ErrorResponse."""
24
+ error_response = ErrorResponse(
25
+ error_code=exc.error_code,
26
+ detail=exc.detail,
27
+ source=exc.source,
28
+ provider=exc.provider
29
+ )
30
+ logger.error(
31
+ f"AI Provider Error: {exc.error_code} - {exc.detail} "
32
+ f"(Provider: {exc.provider}, Status: {exc.status_code})"
33
+ )
34
+ return JSONResponse(
35
+ status_code=exc.status_code,
36
+ content=error_response.model_dump()
37
+ )
38
+
39
+
40
+ # Global exception handler for generic HTTPException
41
+ @app.exception_handler(Exception)
42
+ async def generic_exception_handler(request: Request, exc: Exception):
43
+ """Catch-all exception handler for unexpected errors."""
44
+ # Log the full exception for debugging
45
+ logger.exception(f"Unhandled exception: {str(exc)}")
46
+
47
+ # Return structured error response
48
+ error_response = ErrorResponse(
49
+ error_code="INTERNAL_ERROR",
50
+ detail="An unexpected error occurred. Please try again later.",
51
+ source="INTERNAL"
52
+ )
53
+ return JSONResponse(
54
+ status_code=500,
55
+ content=error_response.model_dump()
56
+ )
57
+
58
+
59
+ @app.on_event("startup")
60
+ async def startup_event():
61
+ """Initialize application on startup."""
62
+ logger.info("Starting application initialization...")
63
+
64
+ # Register all MCP tools with the tool registry
65
+ try:
66
+ register_all_tools()
67
+ logger.info("MCP tools registered successfully")
68
+ except Exception as e:
69
+ logger.error(f"Failed to register MCP tools: {str(e)}")
70
+ raise
71
+
72
+ logger.info("Application initialization complete")
73
+
74
  # Configure CORS
75
  app.add_middleware(
76
  CORSMiddleware,
 
83
  # Register routes
84
  app.include_router(auth.router)
85
  app.include_router(tasks.router)
86
+ app.include_router(chat.router)
87
+ app.include_router(conversations.router)
88
 
89
 
90
  @app.get("/")
src/mcp/__init__.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MCP Tools Registration
3
+
4
+ Registers all MCP tools with the global tool registry.
5
+ Each tool is registered with its contract definition (name, description, parameters).
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+
12
+ from .tool_registry import tool_registry
13
+ from .tools.add_task import add_task
14
+ from .tools.list_tasks import list_tasks
15
+ from .tools.complete_task import complete_task
16
+ from .tools.delete_task import delete_task
17
+ from .tools.update_task import update_task
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def load_tool_contract(tool_name: str) -> dict:
23
+ """
24
+ Load tool contract definition from JSON file.
25
+
26
+ Args:
27
+ tool_name: Name of the tool (e.g., "add_task")
28
+
29
+ Returns:
30
+ Tool contract dictionary
31
+
32
+ Raises:
33
+ FileNotFoundError: If contract file not found
34
+ """
35
+ # Get the project root directory
36
+ current_file = Path(__file__)
37
+ project_root = current_file.parent.parent.parent # backend/src/mcp -> backend
38
+ contract_path = project_root.parent / "specs" / "001-openai-agent-mcp-tools" / "contracts" / f"{tool_name}.json"
39
+
40
+ if not contract_path.exists():
41
+ raise FileNotFoundError(f"Contract file not found: {contract_path}")
42
+
43
+ with open(contract_path, "r") as f:
44
+ return json.load(f)
45
+
46
+
47
+ def register_all_tools():
48
+ """
49
+ Register all MCP tools with the global tool registry.
50
+
51
+ This function should be called during application startup to ensure
52
+ all tools are available to the agent.
53
+ """
54
+ logger.info("Registering MCP tools...")
55
+
56
+ # Register add_task tool
57
+ try:
58
+ add_task_contract = load_tool_contract("add_task")
59
+ tool_registry.register_tool(
60
+ name=add_task_contract["name"],
61
+ description=add_task_contract["description"],
62
+ parameters=add_task_contract["parameters"],
63
+ handler=add_task
64
+ )
65
+ logger.info("Registered tool: add_task")
66
+ except Exception as e:
67
+ logger.error(f"Failed to register add_task tool: {str(e)}")
68
+ raise
69
+
70
+ # Register list_tasks tool
71
+ try:
72
+ list_tasks_contract = load_tool_contract("list_tasks")
73
+ tool_registry.register_tool(
74
+ name=list_tasks_contract["name"],
75
+ description=list_tasks_contract["description"],
76
+ parameters=list_tasks_contract["parameters"],
77
+ handler=list_tasks
78
+ )
79
+ logger.info("Registered tool: list_tasks")
80
+ except Exception as e:
81
+ logger.error(f"Failed to register list_tasks tool: {str(e)}")
82
+ raise
83
+
84
+ # Register complete_task tool
85
+ try:
86
+ complete_task_contract = load_tool_contract("complete_task")
87
+ tool_registry.register_tool(
88
+ name=complete_task_contract["name"],
89
+ description=complete_task_contract["description"],
90
+ parameters=complete_task_contract["parameters"],
91
+ handler=complete_task
92
+ )
93
+ logger.info("Registered tool: complete_task")
94
+ except Exception as e:
95
+ logger.error(f"Failed to register complete_task tool: {str(e)}")
96
+ raise
97
+
98
+ # Register delete_task tool
99
+ try:
100
+ delete_task_contract = load_tool_contract("delete_task")
101
+ tool_registry.register_tool(
102
+ name=delete_task_contract["name"],
103
+ description=delete_task_contract["description"],
104
+ parameters=delete_task_contract["parameters"],
105
+ handler=delete_task
106
+ )
107
+ logger.info("Registered tool: delete_task")
108
+ except Exception as e:
109
+ logger.error(f"Failed to register delete_task tool: {str(e)}")
110
+ raise
111
+
112
+ # Register update_task tool
113
+ try:
114
+ update_task_contract = load_tool_contract("update_task")
115
+ tool_registry.register_tool(
116
+ name=update_task_contract["name"],
117
+ description=update_task_contract["description"],
118
+ parameters=update_task_contract["parameters"],
119
+ handler=update_task
120
+ )
121
+ logger.info("Registered tool: update_task")
122
+ except Exception as e:
123
+ logger.error(f"Failed to register update_task tool: {str(e)}")
124
+ raise
125
+
126
+ logger.info(f"Successfully registered {len(tool_registry.list_tools())} MCP tools")
127
+
128
+
129
+ # Export the global registry instance for use in other modules
130
+ __all__ = ["tool_registry", "register_all_tools"]
src/mcp/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (5.06 kB). View file
 
src/mcp/__pycache__/tool_registry.cpython-313.pyc ADDED
Binary file (5.78 kB). View file
 
src/mcp/tool_registry.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MCP Tool Registry
3
+
4
+ Manages registration and execution of MCP tools with user context injection.
5
+ Security: user_id is injected by the backend, never trusted from LLM output.
6
+ """
7
+
8
+ from typing import Dict, List, Any, Callable, Optional
9
+ from dataclasses import dataclass
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ @dataclass
16
+ class ToolDefinition:
17
+ """Definition of an MCP tool for LLM function calling."""
18
+ name: str
19
+ description: str
20
+ parameters: Dict[str, Any]
21
+
22
+
23
+ @dataclass
24
+ class ToolExecutionResult:
25
+ """Result of executing an MCP tool."""
26
+ success: bool
27
+ data: Optional[Dict[str, Any]] = None
28
+ message: Optional[str] = None
29
+ error: Optional[str] = None
30
+
31
+
32
+ class MCPToolRegistry:
33
+ """
34
+ Registry for MCP tools with user context injection.
35
+
36
+ This class manages tool registration and execution, ensuring that
37
+ user_id is always injected by the backend for security.
38
+ """
39
+
40
+ def __init__(self):
41
+ self._tools: Dict[str, Callable] = {}
42
+ self._tool_definitions: Dict[str, ToolDefinition] = {}
43
+
44
+ def register_tool(
45
+ self,
46
+ name: str,
47
+ description: str,
48
+ parameters: Dict[str, Any],
49
+ handler: Callable
50
+ ) -> None:
51
+ """
52
+ Register an MCP tool with its handler function.
53
+
54
+ Args:
55
+ name: Tool name (e.g., "add_task")
56
+ description: Tool description for LLM
57
+ parameters: JSON schema for tool parameters
58
+ handler: Async function that executes the tool
59
+ """
60
+ self._tools[name] = handler
61
+ self._tool_definitions[name] = ToolDefinition(
62
+ name=name,
63
+ description=description,
64
+ parameters=parameters
65
+ )
66
+ logger.info(f"Registered MCP tool: {name}")
67
+
68
+ def get_tool_definitions(self) -> List[Dict[str, Any]]:
69
+ """
70
+ Get tool definitions in format suitable for LLM function calling.
71
+
72
+ Returns:
73
+ List of tool definitions with name, description, and parameters
74
+ """
75
+ return [
76
+ {
77
+ "name": tool_def.name,
78
+ "description": tool_def.description,
79
+ "parameters": tool_def.parameters
80
+ }
81
+ for tool_def in self._tool_definitions.values()
82
+ ]
83
+
84
+ async def execute_tool(
85
+ self,
86
+ tool_name: str,
87
+ arguments: Dict[str, Any],
88
+ user_id: int
89
+ ) -> ToolExecutionResult:
90
+ """
91
+ Execute an MCP tool with user context injection.
92
+
93
+ SECURITY: user_id is injected by the backend, never from LLM output.
94
+
95
+ Args:
96
+ tool_name: Name of the tool to execute
97
+ arguments: Tool arguments from LLM
98
+ user_id: User ID (injected by backend, not from LLM)
99
+
100
+ Returns:
101
+ ToolExecutionResult with success status and data/error
102
+ """
103
+ if tool_name not in self._tools:
104
+ logger.error(f"Tool not found: {tool_name}")
105
+ return ToolExecutionResult(
106
+ success=False,
107
+ error=f"Tool '{tool_name}' not found"
108
+ )
109
+
110
+ try:
111
+ # Inject user_id into arguments for security
112
+ arguments_with_context = {**arguments, "user_id": user_id}
113
+
114
+ logger.info(f"Executing tool: {tool_name} for user: {user_id}")
115
+
116
+ # Execute the tool handler
117
+ handler = self._tools[tool_name]
118
+ result = await handler(**arguments_with_context)
119
+
120
+ logger.info(f"Tool execution successful: {tool_name}")
121
+ return result
122
+
123
+ except Exception as e:
124
+ logger.error(f"Tool execution failed: {tool_name} - {str(e)}")
125
+ return ToolExecutionResult(
126
+ success=False,
127
+ error=f"Tool execution failed: {str(e)}"
128
+ )
129
+
130
+ def list_tools(self) -> List[str]:
131
+ """Get list of registered tool names."""
132
+ return list(self._tools.keys())
133
+
134
+ def has_tool(self, tool_name: str) -> bool:
135
+ """Check if a tool is registered."""
136
+ return tool_name in self._tools
137
+
138
+
139
+ # Global registry instance
140
+ tool_registry = MCPToolRegistry()