diff --git a/.dockerignore b/.dockerignore index 549d4e3e0296d7c1ff99155c80f382cfcac5244e..f7a4ebcd15d9c63500fcd74f77a23730fbaa4dcc 100644 --- a/.dockerignore +++ b/.dockerignore @@ -43,3 +43,21 @@ alembic/versions/*.pyc # Documentation docs/_build/ +# Python +__pycache__/ +*.pyc +*.pyo +*.pyd + +# Environment +.env + +# Alembic cache +alembic/versions/__pycache__/ + +# Node +node_modules/ +package-lock.json + +# Temp / AI tools +tmpclaude-* diff --git a/.env b/.env index 209239a4a5c1f65c7ad7fb01f0648e419fea79b6..28f9d5ace2c61a380a133c4ce21b87a73b199778 100644 --- a/.env +++ b/.env @@ -2,6 +2,7 @@ # For local PostgreSQL: postgresql://user:password@localhost:5432/todo_db # For Neon: Use your Neon connection string from the dashboard DATABASE_URL=postgresql://neondb_owner:npg_MmFvJBHT8Y0k@ep-silent-thunder-ab0rbvrp-pooler.eu-west-2.aws.neon.tech/neondb?sslmode=require&channel_binding=require + # Application Settings APP_NAME=Task CRUD API DEBUG=True @@ -11,3 +12,18 @@ CORS_ORIGINS=http://localhost:3000 BETTER_AUTH_SECRET=zMdW1P03wJvWJnLKzQ8YYO26vHeinqmR JWT_ALGORITHM=HS256 JWT_EXPIRATION_DAYS=7 + +# LLM Provider Configuration +LLM_PROVIDER=gemini +# FALLBACK_PROVIDER=openrouter +GEMINI_API_KEY=AIzaSyCAlcHZxp5ELh1GqJwKqBLQziUNi0vnobU +# OPENROUTER_API_KEY=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 +COHERE_API_KEY= + +# Agent Configuration +AGENT_TEMPERATURE=0.7 +AGENT_MAX_TOKENS=500 + +# Conversation Settings +CONVERSATION_MAX_MESSAGES=20 +CONVERSATION_MAX_TOKENS=2000 diff --git a/.env.example b/.env.example index e9232d3d85382220f94f2167488f581a64add4ae..c3dc22cf2a9d1571b2489e01862d8c10dd440e5c 100644 --- a/.env.example +++ b/.env.example @@ -6,7 +6,32 @@ APP_NAME=Task CRUD API DEBUG=True CORS_ORIGINS=http://localhost:3000 -# Authentication (Placeholder for Spec 2) -# JWT_SECRET=your-secret-key-here -# JWT_ALGORITHM=HS256 -# JWT_EXPIRATION_MINUTES=1440 +# Authentication +BETTER_AUTH_SECRET=your-secret-key-here-min-32-characters 
+JWT_ALGORITHM=HS256 +JWT_EXPIRATION_DAYS=7 + +# LLM Provider Configuration +# Primary provider: gemini, openrouter, cohere +LLM_PROVIDER=gemini + +# Optional fallback provider (recommended for production) +FALLBACK_PROVIDER=openrouter + +# API Keys (provide at least one for your primary provider) +GEMINI_API_KEY=your-gemini-api-key-here +OPENROUTER_API_KEY=your-openrouter-api-key-here +COHERE_API_KEY=your-cohere-api-key-here + +# Agent Configuration +AGENT_TEMPERATURE=0.7 +AGENT_MAX_TOKENS=8192 + +# Conversation Settings (for free-tier constraints) +CONVERSATION_MAX_MESSAGES=20 +CONVERSATION_MAX_TOKENS=8000 + +# How to get API keys: +# - Gemini: https://makersuite.google.com/app/apikey (free, no credit card required) +# - OpenRouter: https://openrouter.ai/ (free models available) +# - Cohere: https://cohere.com/ (trial only, not recommended for production) diff --git a/README.md b/README.md index 919e8b1bdd2ff87b38cf36b4fc4b0974d4e9e2ca..9c751452aacb31ec2c8f470be750a8bb776034e1 100644 --- a/README.md +++ b/README.md @@ -10,24 +10,227 @@ license: mit # TaskFlow Backend API -FastAPI backend for TaskFlow task management application. +FastAPI backend for TaskFlow task management application with AI chatbot integration. 
## Features -- User authentication with JWT +- User authentication with JWT and Better Auth - Task CRUD operations +- **AI Chatbot Assistant** - Conversational AI for task management - PostgreSQL database with SQLModel ORM - RESTful API design +- Multi-turn conversation support with context management +- Intent recognition for todo-related requests + +## Tech Stack + +- **Framework**: FastAPI 0.104.1 +- **ORM**: SQLModel 0.0.14 +- **Database**: PostgreSQL (Neon Serverless) +- **Authentication**: Better Auth + JWT +- **AI Provider**: Google Gemini (gemini-pro) +- **Migrations**: Alembic 1.13.0 ## Environment Variables -Configure these in your Space settings: +Configure these in your `.env` file: + +### Database +- `DATABASE_URL`: PostgreSQL connection string (Neon or local) + +### Application +- `APP_NAME`: Application name (default: "Task CRUD API") +- `DEBUG`: Debug mode (default: True) +- `CORS_ORIGINS`: Allowed CORS origins (default: "http://localhost:3000") + +### Authentication +- `BETTER_AUTH_SECRET`: Secret key for Better Auth (required) +- `JWT_ALGORITHM`: JWT algorithm (default: "HS256") +- `JWT_EXPIRATION_DAYS`: Token expiration in days (default: 7) + +### AI Provider Configuration +- `AI_PROVIDER`: AI provider to use (default: "gemini") +- `GEMINI_API_KEY`: Google Gemini API key (required if using Gemini) +- `OPENROUTER_API_KEY`: OpenRouter API key (optional) +- `COHERE_API_KEY`: Cohere API key (optional) + +### Conversation Settings +- `MAX_CONVERSATION_MESSAGES`: Maximum messages to keep in history (default: 20) +- `MAX_CONVERSATION_TOKENS`: Maximum tokens to keep in history (default: 8000) + +## Setup Instructions + +### 1. Install Dependencies + +```bash +pip install -r requirements.txt +``` + +### 2. 
Configure Environment + +Create a `.env` file in the `backend/` directory: -- `DATABASE_URL`: PostgreSQL connection string -- `SECRET_KEY`: JWT secret key (generate a secure random string) -- `ALGORITHM`: JWT algorithm (default: HS256) -- `ACCESS_TOKEN_EXPIRE_MINUTES`: Token expiration time (default: 30) +```env +# Database +DATABASE_URL=postgresql://user:password@localhost:5432/todo_db + +# Application +APP_NAME=Task CRUD API +DEBUG=True +CORS_ORIGINS=http://localhost:3000 + +# Authentication +BETTER_AUTH_SECRET=your_secret_key_here +JWT_ALGORITHM=HS256 +JWT_EXPIRATION_DAYS=7 + +# AI Provider +AI_PROVIDER=gemini +GEMINI_API_KEY=your_gemini_api_key_here + +# Conversation Settings +MAX_CONVERSATION_MESSAGES=20 +MAX_CONVERSATION_TOKENS=8000 +``` + +### 3. Run Database Migrations + +```bash +alembic upgrade head +``` + +### 4. Start the Server + +```bash +uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 +``` + +The API will be available at `http://localhost:8000` ## API Documentation -Once deployed, visit `/docs` for interactive API documentation. 
+Once running, visit: +- **Interactive Docs**: `http://localhost:8000/docs` +- **ReDoc**: `http://localhost:8000/redoc` + +## API Endpoints + +### Authentication +- `POST /api/auth/signup` - Register new user +- `POST /api/auth/login` - Login user + +### Tasks +- `GET /api/{user_id}/tasks` - Get all tasks for user +- `POST /api/{user_id}/tasks` - Create new task +- `GET /api/{user_id}/tasks/{task_id}` - Get specific task +- `PUT /api/{user_id}/tasks/{task_id}` - Update task +- `DELETE /api/{user_id}/tasks/{task_id}` - Delete task + +### AI Chat (New in Phase 1) +- `POST /api/{user_id}/chat` - Send message to AI assistant + +#### Chat Request Body +```json +{ + "message": "Can you help me organize my tasks?", + "conversation_id": 123, // Optional: null for new conversation + "temperature": 0.7 // Optional: 0.0 to 1.0 +} +``` + +#### Chat Response +```json +{ + "conversation_id": 123, + "message": "I'd be happy to help you organize your tasks!", + "role": "assistant", + "timestamp": "2026-01-14T10:30:00Z", + "token_count": 25, + "model": "gemini-pro" +} +``` + +## AI Chatbot Features + +### Phase 1 (Current) +- ✅ Natural conversation with AI assistant +- ✅ Multi-turn conversations with context retention +- ✅ Intent recognition for todo-related requests +- ✅ Conversation history persistence +- ✅ Automatic history trimming (20 messages / 8000 tokens) +- ✅ Free-tier AI provider support (Gemini) + +### Phase 2 (Coming Soon) +- 🔄 MCP tools for task CRUD operations +- 🔄 AI can directly create, update, and delete tasks +- 🔄 Natural language task management + +## Error Handling + +The API returns standard HTTP status codes: + +- `200 OK` - Request successful +- `400 Bad Request` - Invalid request data +- `401 Unauthorized` - Authentication required or failed +- `404 Not Found` - Resource not found +- `429 Too Many Requests` - Rate limit exceeded +- `500 Internal Server Error` - Server error + +## Database Schema + +### Users Table +- `id`: Primary key +- `email`: Unique 
email address +- `name`: User's name +- `password`: Hashed password +- `created_at`, `updated_at`: Timestamps + +### Tasks Table +- `id`: Primary key +- `user_id`: Foreign key to users +- `title`: Task title +- `description`: Task description +- `completed`: Boolean status +- `created_at`, `updated_at`: Timestamps + +### Conversation Table (New) +- `id`: Primary key +- `user_id`: Foreign key to users +- `title`: Conversation title +- `created_at`, `updated_at`: Timestamps + +### Message Table (New) +- `id`: Primary key +- `conversation_id`: Foreign key to conversation +- `role`: "user" or "assistant" +- `content`: Message text +- `timestamp`: Message timestamp +- `token_count`: Token count for the message + +## Development + +### Running Tests +```bash +pytest +``` + +### Database Migrations + +Create a new migration: +```bash +alembic revision -m "description" +``` + +Apply migrations: +```bash +alembic upgrade head +``` + +Rollback migration: +```bash +alembic downgrade -1 +``` + +## License + +MIT License diff --git a/alembic/versions/20260114_1044_48b10b49730f_add_conversation_and_message_tables.py b/alembic/versions/20260114_1044_48b10b49730f_add_conversation_and_message_tables.py new file mode 100644 index 0000000000000000000000000000000000000000..f88546201073b7454497440f8c1e152af7c29d0d --- /dev/null +++ b/alembic/versions/20260114_1044_48b10b49730f_add_conversation_and_message_tables.py @@ -0,0 +1,59 @@ +"""add_conversation_and_message_tables + +Revision ID: 48b10b49730f +Revises: 002_add_user_password +Create Date: 2026-01-14 10:44:27.010796 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '48b10b49730f' +down_revision = '002_add_user_password' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Create conversation table + op.create_table( + 'conversation', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('title', sa.String(length=255), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('ix_conversation_user_id', 'conversation', ['user_id']) + op.create_index('ix_conversation_created_at', 'conversation', ['created_at']) + + # Create message table + op.create_table( + 'message', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('conversation_id', sa.Integer(), nullable=False), + sa.Column('role', sa.String(length=50), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('timestamp', sa.DateTime(), server_default=sa.text('now()'), nullable=False), + sa.Column('token_count', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['conversation_id'], ['conversation.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('ix_message_conversation_id', 'message', ['conversation_id']) + op.create_index('ix_message_timestamp', 'message', ['timestamp']) + + +def downgrade() -> None: + # Drop message table first (due to foreign key dependency) + op.drop_index('ix_message_timestamp', table_name='message') + op.drop_index('ix_message_conversation_id', table_name='message') + op.drop_table('message') + + # Drop conversation table + op.drop_index('ix_conversation_created_at', table_name='conversation') + op.drop_index('ix_conversation_user_id', table_name='conversation') + op.drop_table('conversation') diff --git 
a/alembic/versions/20260114_1115_37ca2e18468d_description.py b/alembic/versions/20260114_1115_37ca2e18468d_description.py new file mode 100644 index 0000000000000000000000000000000000000000..8d607c6c4ceb552ed964eb8a5563144fa97831fa --- /dev/null +++ b/alembic/versions/20260114_1115_37ca2e18468d_description.py @@ -0,0 +1,24 @@ +"""description + +Revision ID: 37ca2e18468d +Revises: 48b10b49730f +Create Date: 2026-01-14 11:15:03.691055 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '37ca2e18468d' +down_revision = '48b10b49730f' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/alembic/versions/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.py b/alembic/versions/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb23aa4266ab97bb1e66706d3257b3a4d4e3c2b --- /dev/null +++ b/alembic/versions/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.py @@ -0,0 +1,27 @@ +"""add metadata column to message table + +Revision ID: a3c44bf7ddcb +Revises: 37ca2e18468d +Create Date: 2026-01-14 17:02:51.060200 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. 
+revision = 'a3c44bf7ddcb' +down_revision = '37ca2e18468d' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add metadata column to message table + op.add_column('message', sa.Column('metadata', postgresql.JSON(astext_type=sa.Text()), nullable=True)) + + +def downgrade() -> None: + # Remove metadata column from message table + op.drop_column('message', 'metadata') diff --git a/alembic/versions/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.py b/alembic/versions/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.py new file mode 100644 index 0000000000000000000000000000000000000000..e963c45440bf6dd24dfe5be04ef88e74b1ec09c5 --- /dev/null +++ b/alembic/versions/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.py @@ -0,0 +1,26 @@ +"""rename metadata to tool_metadata in message table + +Revision ID: e8275e6c143c +Revises: a3c44bf7ddcb +Create Date: 2026-01-14 17:12:53.740315 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = 'e8275e6c143c' +down_revision = 'a3c44bf7ddcb' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Rename metadata column to tool_metadata + op.alter_column('message', 'metadata', new_column_name='tool_metadata') + + +def downgrade() -> None: + # Rename tool_metadata column back to metadata + op.alter_column('message', 'tool_metadata', new_column_name='metadata') diff --git a/alembic/versions/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.py b/alembic/versions/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.py new file mode 100644 index 0000000000000000000000000000000000000000..60456009adfc9c5a8ee605db0eafb454f008d085 --- /dev/null +++ b/alembic/versions/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.py @@ -0,0 +1,28 @@ +"""add due_date and priority to task table + +Revision ID: d34db62bd406 +Revises: e8275e6c143c +Create Date: 2026-01-14 19:00:45.426280 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = 'd34db62bd406' +down_revision = 'e8275e6c143c' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add due_date and priority columns to tasks table + op.add_column('tasks', sa.Column('due_date', sa.Date(), nullable=True)) + op.add_column('tasks', sa.Column('priority', sa.String(length=20), nullable=False, server_default='medium')) + + +def downgrade() -> None: + # Remove due_date and priority columns from tasks table + op.drop_column('tasks', 'priority') + op.drop_column('tasks', 'due_date') diff --git a/alembic/versions/__pycache__/20260114_1044_48b10b49730f_add_conversation_and_message_tables.cpython-313.pyc b/alembic/versions/__pycache__/20260114_1044_48b10b49730f_add_conversation_and_message_tables.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..314f72085e13899bc44c02d0f602a7358615ee55 Binary files /dev/null and b/alembic/versions/__pycache__/20260114_1044_48b10b49730f_add_conversation_and_message_tables.cpython-313.pyc differ diff --git a/alembic/versions/__pycache__/20260114_1115_37ca2e18468d_description.cpython-313.pyc b/alembic/versions/__pycache__/20260114_1115_37ca2e18468d_description.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee060fe9897bbcd1300881170de7dce3fe368573 Binary files /dev/null and b/alembic/versions/__pycache__/20260114_1115_37ca2e18468d_description.cpython-313.pyc differ diff --git a/alembic/versions/__pycache__/20260114_1659_84d7d00c71ef_add_metadata_to_message_table.cpython-313.pyc b/alembic/versions/__pycache__/20260114_1659_84d7d00c71ef_add_metadata_to_message_table.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7879111c4b5a6e196a19da59ee58ccb8db7b741 Binary files /dev/null and b/alembic/versions/__pycache__/20260114_1659_84d7d00c71ef_add_metadata_to_message_table.cpython-313.pyc differ diff --git 
a/alembic/versions/__pycache__/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.cpython-313.pyc b/alembic/versions/__pycache__/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..199941856a26d00c1cd6997a5dc322d41317081d Binary files /dev/null and b/alembic/versions/__pycache__/20260114_1702_a3c44bf7ddcb_add_metadata_column_to_message_table.cpython-313.pyc differ diff --git a/alembic/versions/__pycache__/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.cpython-313.pyc b/alembic/versions/__pycache__/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..deb4517d3f4a100ad9055757bbc0cbefececfea3 Binary files /dev/null and b/alembic/versions/__pycache__/20260114_1712_e8275e6c143c_rename_metadata_to_tool_metadata_in_.cpython-313.pyc differ diff --git a/alembic/versions/__pycache__/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.cpython-313.pyc b/alembic/versions/__pycache__/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac5c4a2263d764b3dd8188eb50d6126f7855e96 Binary files /dev/null and b/alembic/versions/__pycache__/20260114_1900_d34db62bd406_add_due_date_and_priority_to_task_table.cpython-313.pyc differ diff --git a/alembic/versions/tmpclaude-82e1-cwd b/alembic/versions/tmpclaude-82e1-cwd new file mode 100644 index 0000000000000000000000000000000000000000..ee300cbd045cd025aaefca37c24b80d2c87e0796 --- /dev/null +++ b/alembic/versions/tmpclaude-82e1-cwd @@ -0,0 +1 @@ +/d/Agentic_ai_learning/hacathoon_2/evolution-of-todo/phase-2-full-stack-web-app/backend/alembic/versions diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000000000000000000000000000000000000..dfb18f1156fa38a107aad43c8d0dbe59fe56624e --- 
/dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "backend", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/requirements.txt b/requirements.txt index c2286fe76002f809a53e6b02d7a67217139f395a..7622c6d729900dc6d6bc4bea8b02227255dfc305 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,10 @@ passlib[bcrypt]==1.7.4 python-multipart==0.0.6 # Add this line (or replace if bcrypt is already listed) -bcrypt==4.3.0 # Last stable version before the 5.0 break \ No newline at end of file +bcrypt==4.3.0 # Last stable version before the 5.0 break + +# AI Chatbot dependencies +google-generativeai==0.3.2 # Gemini API client +tiktoken==0.5.2 # Token counting (optional) +mcp==1.20.0 # Official MCP SDK for tool server +cohere==4.37 # Cohere API client (optional provider) \ No newline at end of file diff --git a/src/__pycache__/main.cpython-313.pyc b/src/__pycache__/main.cpython-313.pyc index 0f7532570323f62653a42c188ab9572e7551355f..b69422dcacfa1435badd5334d2739147b2a1f2d2 100644 Binary files a/src/__pycache__/main.cpython-313.pyc and b/src/__pycache__/main.cpython-313.pyc differ diff --git a/src/agent/__init__.py b/src/agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/agent/__pycache__/__init__.cpython-313.pyc b/src/agent/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c479bdb9c2913d55ba1bc502edfc7842cf51718 Binary files /dev/null and b/src/agent/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/agent/__pycache__/agent_config.cpython-313.pyc b/src/agent/__pycache__/agent_config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..807da797003203d4ec51c749e32775619c2a22d6 Binary files /dev/null and b/src/agent/__pycache__/agent_config.cpython-313.pyc differ diff --git a/src/agent/__pycache__/agent_runner.cpython-313.pyc 
b/src/agent/__pycache__/agent_runner.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..685b75855a69aa2eb71c3aeba395245bb467c899 Binary files /dev/null and b/src/agent/__pycache__/agent_runner.cpython-313.pyc differ diff --git a/src/agent/agent_config.py b/src/agent/agent_config.py new file mode 100644 index 0000000000000000000000000000000000000000..851b32732f47ac218b90a9155a3c46ab1658d456 --- /dev/null +++ b/src/agent/agent_config.py @@ -0,0 +1,124 @@ +""" +Agent Configuration + +Configuration dataclass for agent behavior and LLM provider settings. +""" + +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class AgentConfiguration: + """ + Configuration for the AI agent. + + This dataclass defines all configurable parameters for agent behavior, + including LLM provider settings, conversation limits, and system prompts. + """ + + # Provider settings + provider: str = "gemini" # Options: gemini, openrouter, cohere + fallback_provider: Optional[str] = None # Optional fallback provider + model: Optional[str] = None # Model name (provider-specific) + + # API keys (loaded from environment) + gemini_api_key: Optional[str] = None + openrouter_api_key: Optional[str] = None + cohere_api_key: Optional[str] = None + + # Generation parameters + temperature: float = 0.7 # Sampling temperature (0.0 to 1.0) + max_tokens: int = 8192 # Maximum tokens in response + + # Conversation history limits (for free-tier constraints) + max_messages: int = 20 # Maximum messages to keep in history + max_conversation_tokens: int = 8000 # Maximum tokens in conversation history + + # System prompt + system_prompt: str = """You are a helpful AI assistant for managing tasks. +You can help users create, view, complete, update, and delete tasks using natural language. 
+ +Available tools: +- add_task: Create a new task +- list_tasks: View all tasks (with optional filtering) +- complete_task: Mark a task as completed +- delete_task: Remove a task +- update_task: Modify task properties + +Always respond in a friendly, conversational manner and confirm actions taken.""" + + # Retry settings + max_retries: int = 3 # Maximum retries on rate limit errors + retry_delay: float = 1.0 # Delay between retries (seconds) + + def get_provider_api_key(self, provider_name: str) -> Optional[str]: + """ + Get API key for a specific provider. + + Args: + provider_name: Provider name (gemini, openrouter, cohere) + + Returns: + API key or None if not configured + """ + if provider_name == "gemini": + return self.gemini_api_key + elif provider_name == "openrouter": + return self.openrouter_api_key + elif provider_name == "cohere": + return self.cohere_api_key + return None + + def get_provider_model(self, provider_name: str) -> str: + """ + Get default model for a specific provider. + + Args: + provider_name: Provider name + + Returns: + Model identifier + """ + if self.model: + return self.model + + # Default models per provider + defaults = { + "gemini": "gemini-flash-latest", + "openrouter": "google/gemini-flash-1.5", + "cohere": "command-r-plus" + } + return defaults.get(provider_name, "gemini-flash-latest") + + def validate(self) -> bool: + """ + Validate configuration. 
+ + Returns: + True if configuration is valid + + Raises: + ValueError: If configuration is invalid + """ + # Check primary provider has API key + primary_key = self.get_provider_api_key(self.provider) + if not primary_key: + raise ValueError(f"API key not configured for primary provider: {self.provider}") + + # Validate temperature range + if not 0.0 <= self.temperature <= 1.0: + raise ValueError(f"Temperature must be between 0.0 and 1.0, got: {self.temperature}") + + # Validate max_tokens + if self.max_tokens <= 0: + raise ValueError(f"max_tokens must be positive, got: {self.max_tokens}") + + # Validate conversation limits + if self.max_messages <= 0: + raise ValueError(f"max_messages must be positive, got: {self.max_messages}") + + if self.max_conversation_tokens <= 0: + raise ValueError(f"max_conversation_tokens must be positive, got: {self.max_conversation_tokens}") + + return True diff --git a/src/agent/agent_runner.py b/src/agent/agent_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..ded29fb9765fd82317dee7bdba8655ea505e5521 --- /dev/null +++ b/src/agent/agent_runner.py @@ -0,0 +1,281 @@ +""" +Agent Runner + +Core orchestrator for AI agent execution with tool calling support. +Manages the full request cycle: LLM generation → tool execution → final response. +""" + +import logging +from typing import List, Dict, Any, Optional +import asyncio + +from .agent_config import AgentConfiguration +from .providers.base import LLMProvider +from .providers.gemini import GeminiProvider +from .providers.openrouter import OpenRouterProvider +from .providers.cohere import CohereProvider +from ..mcp.tool_registry import MCPToolRegistry, ToolExecutionResult + +logger = logging.getLogger(__name__) + + +class AgentRunner: + """ + Agent execution orchestrator with tool calling support. + + This class manages the full agent request cycle: + 1. Generate LLM response with tool definitions + 2. 
If tool calls requested, execute tools with user context injection + 3. Generate final response with tool results + 4. Handle rate limiting with fallback providers + """ + + def __init__(self, config: AgentConfiguration, tool_registry: MCPToolRegistry): + """ + Initialize the agent runner. + + Args: + config: Agent configuration + tool_registry: MCP tool registry + """ + self.config = config + self.tool_registry = tool_registry + self.primary_provider = self._create_provider(config.provider) + self.fallback_provider = None + + if config.fallback_provider: + self.fallback_provider = self._create_provider(config.fallback_provider) + + logger.info(f"Initialized AgentRunner with provider: {config.provider}") + + def _create_provider(self, provider_name: str) -> LLMProvider: + """ + Create an LLM provider instance. + + Args: + provider_name: Provider name (gemini, openrouter, cohere) + + Returns: + LLMProvider instance + + Raises: + ValueError: If provider is not supported or API key is missing + """ + api_key = self.config.get_provider_api_key(provider_name) + if not api_key: + raise ValueError(f"API key not configured for provider: {provider_name}") + + model = self.config.get_provider_model(provider_name) + + if provider_name == "gemini": + return GeminiProvider( + api_key=api_key, + model=model, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens + ) + elif provider_name == "openrouter": + return OpenRouterProvider( + api_key=api_key, + model=model, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens + ) + elif provider_name == "cohere": + return CohereProvider( + api_key=api_key, + model=model, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens + ) + else: + raise ValueError(f"Unsupported provider: {provider_name}") + + async def execute( + self, + messages: List[Dict[str, str]], + user_id: int, + system_prompt: Optional[str] = None + ) -> Dict[str, Any]: + """ + Execute agent request with tool 
calling support. + + SECURITY: user_id is injected by backend, never from LLM output. + + Args: + messages: Conversation history [{"role": "user", "content": "..."}] + user_id: User ID (injected by backend for security) + system_prompt: Optional system prompt (uses config default if not provided) + + Returns: + Dict with response content and metadata + """ + prompt = system_prompt or self.config.system_prompt + provider = self.primary_provider + + try: + # Get tool definitions + tool_definitions = self.tool_registry.get_tool_definitions() + + logger.info(f"Executing agent for user {user_id} with {len(tool_definitions)} tools") + + # Generate initial response with tool definitions + response = await provider.generate_response_with_tools( + messages=messages, + system_prompt=prompt, + tools=tool_definitions + ) + + # Check if tool calls were requested + if response.tool_calls: + logger.info(f"Agent requested {len(response.tool_calls)} tool calls") + + # Execute all tool calls + tool_results = [] + for tool_call in response.tool_calls: + result = await self.tool_registry.execute_tool( + tool_name=tool_call["name"], + arguments=tool_call["arguments"], + user_id=user_id # Inject user context for security + ) + tool_results.append(result) + + # Generate final response with tool results + final_response = await provider.generate_response_with_tool_results( + messages=messages, + tool_calls=response.tool_calls, + tool_results=tool_results + ) + + return { + "content": final_response.content, + "tool_calls": response.tool_calls, + "tool_results": tool_results, + "provider": provider.get_provider_name() + } + + # No tool calls, return direct response + logger.info("Agent generated direct response (no tool calls)") + return { + "content": response.content, + "tool_calls": None, + "tool_results": None, + "provider": provider.get_provider_name() + } + + except Exception as e: + logger.error(f"Agent execution failed with primary provider: {str(e)}") + + # Try fallback provider 
if configured + if self.fallback_provider: + logger.info("Attempting fallback provider") + try: + return await self._execute_with_provider( + provider=self.fallback_provider, + messages=messages, + user_id=user_id, + system_prompt=prompt + ) + except Exception as fallback_error: + logger.error(f"Fallback provider also failed: {str(fallback_error)}") + raise + + raise + + async def _execute_with_provider( + self, + provider: LLMProvider, + messages: List[Dict[str, str]], + user_id: int, + system_prompt: str + ) -> Dict[str, Any]: + """ + Execute agent request with a specific provider. + + Args: + provider: LLM provider to use + messages: Conversation history + user_id: User ID + system_prompt: System prompt + + Returns: + Dict with response content and metadata + """ + tool_definitions = self.tool_registry.get_tool_definitions() + + # Generate initial response + response = await provider.generate_response_with_tools( + messages=messages, + system_prompt=system_prompt, + tools=tool_definitions + ) + + # Handle tool calls + if response.tool_calls: + tool_results = [] + for tool_call in response.tool_calls: + result = await self.tool_registry.execute_tool( + tool_name=tool_call["name"], + arguments=tool_call["arguments"], + user_id=user_id + ) + tool_results.append(result) + + final_response = await provider.generate_response_with_tool_results( + messages=messages, + tool_calls=response.tool_calls, + tool_results=tool_results + ) + + return { + "content": final_response.content, + "tool_calls": response.tool_calls, + "tool_results": tool_results, + "provider": provider.get_provider_name() + } + + return { + "content": response.content, + "tool_calls": None, + "tool_results": None, + "provider": provider.get_provider_name() + } + + async def execute_simple( + self, + messages: List[Dict[str, str]], + system_prompt: Optional[str] = None + ) -> str: + """ + Execute a simple agent request without tool calling. 
+ + Args: + messages: Conversation history + system_prompt: Optional system prompt + + Returns: + Response content as string + """ + prompt = system_prompt or self.config.system_prompt + provider = self.primary_provider + + try: + response = await provider.generate_simple_response( + messages=messages, + system_prompt=prompt + ) + return response.content or "" + + except Exception as e: + logger.error(f"Simple execution failed: {str(e)}") + + # Try fallback provider + if self.fallback_provider: + logger.info("Attempting fallback provider for simple execution") + response = await self.fallback_provider.generate_simple_response( + messages=messages, + system_prompt=prompt + ) + return response.content or "" + + raise diff --git a/src/agent/providers/__init__.py b/src/agent/providers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/agent/providers/__pycache__/__init__.cpython-313.pyc b/src/agent/providers/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a55799c8cdd4e917f933b2161749cc5eff41c4cd Binary files /dev/null and b/src/agent/providers/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/agent/providers/__pycache__/base.cpython-313.pyc b/src/agent/providers/__pycache__/base.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c940756793e9643388f04ac66b2afbd65327e684 Binary files /dev/null and b/src/agent/providers/__pycache__/base.cpython-313.pyc differ diff --git a/src/agent/providers/__pycache__/cohere.cpython-313.pyc b/src/agent/providers/__pycache__/cohere.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96c4af19976f8b30a2b59ec2218698e6f45f5dab Binary files /dev/null and b/src/agent/providers/__pycache__/cohere.cpython-313.pyc differ diff --git a/src/agent/providers/__pycache__/gemini.cpython-313.pyc 
"""
LLM Provider Base Class

Abstract interface that every concrete LLM backend (Gemini, OpenRouter,
Cohere) implements, plus the LLMResponse value object they all return.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class LLMResponse:
    """Normalized result of a single LLM call.

    # content: text reply, or None when the model asked for tool calls
    # tool_calls: requested invocations as [{"name": ..., "arguments": ...}]
    # finish_reason: provider-reported stop reason (e.g. "stop", "tool_calls")
    # usage: optional token-usage counters reported by the provider
    """
    content: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, int]] = None


class LLMProvider(ABC):
    """
    Common contract for chat-completion providers.

    Every implementation (Gemini, OpenRouter, Cohere) must support the three
    generation modes below so the agent runner can swap providers freely.
    """

    def __init__(self, api_key: str, model: str, temperature: float = 0.7, max_tokens: int = 8192):
        """Store the connection and generation settings shared by all providers.

        Args:
            api_key: API key for the provider.
            model: Model identifier (e.g., "gemini-1.5-flash").
            temperature: Sampling temperature (0.0 to 1.0).
            max_tokens: Maximum tokens in response.
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens

    @abstractmethod
    async def generate_response_with_tools(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str,
        tools: List[Dict[str, Any]]
    ) -> LLMResponse:
        """Generate a response with function-calling support.

        Args:
            messages: Conversation history [{"role": "user", "content": "..."}].
            system_prompt: System instructions for the agent.
            tools: Tool definitions for function calling.

        Returns:
            LLMResponse carrying content and/or tool_calls.
        """

    @abstractmethod
    async def generate_response_with_tool_results(
        self,
        messages: List[Dict[str, str]],
        tool_calls: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> LLMResponse:
        """Generate the final response after tools have been executed.

        Args:
            messages: Original conversation history.
            tool_calls: Tool calls that were made.
            tool_results: Results from tool execution.

        Returns:
            LLMResponse with the final content.
        """

    @abstractmethod
    async def generate_simple_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str
    ) -> LLMResponse:
        """Generate a plain response without function calling.

        Args:
            messages: Conversation history.
            system_prompt: System instructions.

        Returns:
            LLMResponse with content.
        """

    def get_provider_name(self) -> str:
        """Short lowercase provider name derived from the class name."""
        return self.__class__.__name__.replace("Provider", "").lower()
"""
Cohere Provider Implementation

Cohere API provider with function calling support.
Optional provider (trial only, not recommended for production).
"""

import asyncio
import logging
from typing import Any, Dict, List

import cohere

from .base import LLMProvider, LLMResponse

logger = logging.getLogger(__name__)


class CohereProvider(LLMProvider):
    """
    Cohere API provider implementation.

    Features:
    - Native function calling support
    - Trial tier only (not recommended for production)
    - Model: command-r-plus (best for function calling)

    Note: Cohere requires a paid plan after trial expires.
    Use Gemini or OpenRouter for free-tier operation.

    cohere.Client is synchronous, so every network call is routed through
    asyncio.to_thread to avoid blocking the event loop inside async methods.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "command-r-plus",
        temperature: float = 0.7,
        max_tokens: int = 8192
    ):
        super().__init__(api_key, model, temperature, max_tokens)
        self.client = cohere.Client(api_key)
        logger.info(f"Initialized CohereProvider with model: {model}")

    def _convert_tools_to_cohere_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert MCP tool definitions to Cohere tool format.

        Only the JSON Schema "properties" map is forwarded as
        parameter_definitions; other schema keywords are dropped.

        Args:
            tools: MCP tool definitions.

        Returns:
            List of Cohere-formatted tool definitions.
        """
        return [
            {
                "name": tool["name"],
                "description": tool["description"],
                "parameter_definitions": tool["parameters"].get("properties", {})
            }
            for tool in tools
        ]

    @staticmethod
    def _to_cohere_history(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Map standard role/content messages to Cohere USER/CHATBOT history."""
        return [
            {
                "role": "USER" if msg["role"] == "user" else "CHATBOT",
                "message": msg["content"]
            }
            for msg in messages
        ]

    async def generate_response_with_tools(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str,
        tools: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a response with function calling support.

        Args:
            messages: Conversation history.
            system_prompt: System instructions.
            tools: Tool definitions.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        try:
            cohere_tools = self._convert_tools_to_cohere_format(tools)

            # All messages except the last become history; the last one is the
            # current user turn.
            chat_history = self._to_cohere_history(messages[:-1])
            current_message = messages[-1]["content"] if messages else ""

            # The Cohere client is blocking: run it off the event loop.
            response = await asyncio.to_thread(
                self.client.chat,
                message=current_message,
                chat_history=chat_history,
                preamble=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tools=cohere_tools
            )

            if response.tool_calls:
                tool_calls = [
                    {
                        "name": tc.name,
                        "arguments": tc.parameters
                    }
                    for tc in response.tool_calls
                ]
                logger.info(f"Cohere requested function calls: {[tc['name'] for tc in tool_calls]}")
                return LLMResponse(
                    content=None,
                    tool_calls=tool_calls,
                    finish_reason="tool_calls"
                )

            logger.info("Cohere generated text response")
            return LLMResponse(
                content=response.text,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error(f"Cohere API error: {str(e)}")
            raise

    async def generate_response_with_tool_results(
        self,
        messages: List[Dict[str, str]],
        tool_calls: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a final response after tool execution.

        Args:
            messages: Original conversation history.
            tool_calls: Tool calls that were made.
            tool_results: Results from tool execution.

        Returns:
            LLMResponse with final content.
        """
        try:
            chat_history = self._to_cohere_history(messages)

            # Pair each call with its result in Cohere's tool_results shape.
            tool_results_formatted = [
                {
                    "call": {"name": call["name"], "parameters": call["arguments"]},
                    "outputs": [{"result": str(result)}]
                }
                for call, result in zip(tool_calls, tool_results)
            ]

            response = await asyncio.to_thread(
                self.client.chat,
                message="Based on the tool results, provide a natural language response.",
                chat_history=chat_history,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tool_results=tool_results_formatted
            )

            logger.info("Cohere generated final response after tool execution")
            return LLMResponse(
                content=response.text,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error(f"Cohere API error in tool results: {str(e)}")
            raise

    async def generate_simple_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str
    ) -> LLMResponse:
        """
        Generate a simple response without function calling.

        Args:
            messages: Conversation history.
            system_prompt: System instructions.

        Returns:
            LLMResponse with content.
        """
        try:
            chat_history = self._to_cohere_history(messages[:-1])
            current_message = messages[-1]["content"] if messages else ""

            response = await asyncio.to_thread(
                self.client.chat,
                message=current_message,
                chat_history=chat_history,
                preamble=system_prompt,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )

            logger.info("Cohere generated simple response")
            return LLMResponse(
                content=response.text,
                finish_reason="COMPLETE"
            )

        except Exception as e:
            logger.error(f"Cohere API error: {str(e)}")
            raise
"""
Gemini Provider Implementation

Google Gemini API provider with function calling support.
Primary provider for free-tier operation (15 RPM, 1M token context).
"""

import logging
from typing import Any, Dict, List, Optional

import google.generativeai as genai
from google.generativeai.types import FunctionDeclaration, Tool

from .base import LLMProvider, LLMResponse

logger = logging.getLogger(__name__)


class GeminiProvider(LLMProvider):
    """
    Google Gemini API provider implementation.

    Features:
    - Native function calling support
    - 1M token context window
    - Free tier: 15 requests/minute
    - Model: gemini-1.5-flash (recommended for free tier)

    NOTE: genai.GenerativeModel.generate_content is synchronous; calls here
    block the event loop for the duration of the request.
    """

    def __init__(self, api_key: str, model: str = "gemini-flash-latest", temperature: float = 0.7, max_tokens: int = 8192):
        super().__init__(api_key, model, temperature, max_tokens)
        genai.configure(api_key=api_key)
        self.client = genai.GenerativeModel(model)
        logger.info(f"Initialized GeminiProvider with model: {model}")

    def _sanitize_schema_for_gemini(self, schema: Dict[str, Any]) -> Dict[str, Any]:
        """
        Sanitize JSON Schema to be Gemini-compatible.

        Gemini only supports a subset of JSON Schema keywords:
        - Supported: type, description, enum, required, properties, items
        - NOT supported: maxLength, minLength, pattern, format, minimum,
          maximum, default, etc.

        Args:
            schema: Original JSON Schema.

        Returns:
            Gemini-compatible schema with unsupported fields removed.
        """
        ALLOWED_FIELDS = {
            "type", "description", "enum", "required",
            "properties", "items"
        }

        sanitized = {}
        for key, value in schema.items():
            if key not in ALLOWED_FIELDS:
                continue
            # Recurse into nested schemas so unsupported keywords are stripped
            # at every level.
            if key == "properties" and isinstance(value, dict):
                sanitized[key] = {
                    prop_name: self._sanitize_schema_for_gemini(prop_schema)
                    for prop_name, prop_schema in value.items()
                }
            elif key == "items" and isinstance(value, dict):
                sanitized[key] = self._sanitize_schema_for_gemini(value)
            else:
                sanitized[key] = value

        return sanitized

    def _convert_tools_to_gemini_format(self, tools: List[Dict[str, Any]]) -> List[Tool]:
        """
        Convert MCP tool definitions to Gemini function declarations.

        Sanitizes schemas to remove unsupported JSON Schema keywords.

        Args:
            tools: MCP tool definitions.

        Returns:
            List of Gemini Tool objects.
        """
        function_declarations = []
        for tool in tools:
            sanitized_parameters = self._sanitize_schema_for_gemini(tool["parameters"])
            function_declarations.append(
                FunctionDeclaration(
                    name=tool["name"],
                    description=tool["description"],
                    parameters=sanitized_parameters
                )
            )
            logger.debug(f"Sanitized tool schema for Gemini: {tool['name']}")

        return [Tool(function_declarations=function_declarations)]

    def _convert_messages_to_gemini_format(self, messages: List[Dict[str, str]], system_prompt: str) -> List[Dict[str, Any]]:
        """
        Convert standard message format to Gemini format.

        The system prompt is emulated as a leading user/model exchange because
        this call path does not use a dedicated system slot.

        Args:
            messages: Standard message format [{"role": "user", "content": "..."}].
            system_prompt: System instructions.

        Returns:
            Gemini-formatted messages.
        """
        gemini_messages = []

        if system_prompt:
            gemini_messages.append({
                "role": "user",
                "parts": [{"text": system_prompt}]
            })
            gemini_messages.append({
                "role": "model",
                "parts": [{"text": "Understood. I'll follow these instructions."}]
            })

        for msg in messages:
            role = "user" if msg["role"] == "user" else "model"
            gemini_messages.append({
                "role": role,
                "parts": [{"text": msg["content"]}]
            })

        return gemini_messages

    @staticmethod
    def _candidate_parts(response) -> List[Any]:
        """Return the parts of the first candidate, or [] when absent."""
        try:
            return list(response.candidates[0].content.parts)
        except (AttributeError, IndexError):
            return []

    @classmethod
    def _extract_text(cls, response) -> Optional[str]:
        """
        Join all text parts of the first candidate; None when there is no text.

        Deliberately avoids response.text, which raises ValueError when the
        candidate contains non-text (e.g. function_call) parts.
        """
        texts = [
            part.text
            for part in cls._candidate_parts(response)
            if getattr(part, "text", None)
        ]
        return "".join(texts) if texts else None

    async def generate_response_with_tools(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str,
        tools: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a response with function calling support.

        Args:
            messages: Conversation history.
            system_prompt: System instructions.
            tools: Tool definitions.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        try:
            gemini_tools = self._convert_tools_to_gemini_format(tools)
            gemini_messages = self._convert_messages_to_gemini_format(messages, system_prompt)

            response = self.client.generate_content(
                gemini_messages,
                tools=gemini_tools,
                generation_config={
                    "temperature": self.temperature,
                    "max_output_tokens": self.max_tokens
                }
            )

            # Collect ALL function-call parts: Gemini may emit several calls,
            # and a call is not guaranteed to be the first part.
            tool_calls = [
                {
                    "name": part.function_call.name,
                    "arguments": dict(part.function_call.args)
                }
                for part in self._candidate_parts(response)
                if getattr(part, "function_call", None) and part.function_call.name
            ]
            if tool_calls:
                logger.info(f"Gemini requested function calls: {[tc['name'] for tc in tool_calls]}")
                return LLMResponse(
                    content=None,
                    tool_calls=tool_calls,
                    finish_reason="function_call"
                )

            logger.info("Gemini generated text response")
            return LLMResponse(
                content=self._extract_text(response),
                finish_reason="stop"
            )

        except Exception as e:
            logger.error(f"Gemini API error: {str(e)}")
            raise

    async def generate_response_with_tool_results(
        self,
        messages: List[Dict[str, str]],
        tool_calls: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a final response after tool execution.

        Tool results are injected as plain-text turns because this call path
        does not replay native function-call parts.

        Args:
            messages: Original conversation history.
            tool_calls: Tool calls that were made.
            tool_results: Results from tool execution.

        Returns:
            LLMResponse with final content.
        """
        try:
            tool_results_text = "\n\n".join([
                f"Tool: {call['name']}\nResult: {result}"
                for call, result in zip(tool_calls, tool_results)
            ])

            messages_with_results = messages + [
                {"role": "assistant", "content": f"I called the following tools:\n{tool_results_text}"},
                {"role": "user", "content": "Based on these tool results, provide a natural language response to the user."}
            ]

            gemini_messages = self._convert_messages_to_gemini_format(messages_with_results, "")
            response = self.client.generate_content(
                gemini_messages,
                generation_config={
                    "temperature": self.temperature,
                    "max_output_tokens": self.max_tokens
                }
            )

            logger.info("Gemini generated final response after tool execution")
            return LLMResponse(
                content=self._extract_text(response),
                finish_reason="stop"
            )

        except Exception as e:
            logger.error(f"Gemini API error in tool results: {str(e)}")
            raise

    async def generate_simple_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str
    ) -> LLMResponse:
        """
        Generate a simple response without function calling.

        Args:
            messages: Conversation history.
            system_prompt: System instructions.

        Returns:
            LLMResponse with content.
        """
        try:
            gemini_messages = self._convert_messages_to_gemini_format(messages, system_prompt)
            response = self.client.generate_content(
                gemini_messages,
                generation_config={
                    "temperature": self.temperature,
                    "max_output_tokens": self.max_tokens
                }
            )

            logger.info("Gemini generated simple response")
            return LLMResponse(
                content=self._extract_text(response),
                finish_reason="stop"
            )

        except Exception as e:
            logger.error(f"Gemini API error: {str(e)}")
            raise
"""
OpenRouter Provider Implementation

OpenRouter API provider with function calling support.
Fallback provider for when Gemini rate limits are exceeded.
Uses OpenAI-compatible API format.
"""

import json
import logging
from typing import Any, Dict, List

import httpx

from .base import LLMProvider, LLMResponse

logger = logging.getLogger(__name__)


class OpenRouterProvider(LLMProvider):
    """
    OpenRouter API provider implementation.

    Features:
    - OpenAI-compatible API
    - Access to multiple free models
    - Function calling support
    - Recommended free model: google/gemini-flash-1.5
    """

    def __init__(
        self,
        api_key: str,
        model: str = "google/gemini-flash-1.5",
        temperature: float = 0.7,
        max_tokens: int = 8192
    ):
        super().__init__(api_key, model, temperature, max_tokens)
        self.base_url = "https://openrouter.ai/api/v1"
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        logger.info(f"Initialized OpenRouterProvider with model: {model}")

    def _convert_tools_to_openai_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert MCP tool definitions to OpenAI function format.

        Args:
            tools: MCP tool definitions.

        Returns:
            List of OpenAI-formatted function definitions.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool["description"],
                    "parameters": tool["parameters"]
                }
            }
            for tool in tools
        ]

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> Any:
        """
        Decode OpenAI-style function arguments into a mapping.

        The chat/completions API returns function.arguments as a JSON-encoded
        STRING; downstream tool execution expects a dict (the Gemini provider
        already returns one). Falls back to the raw value when it is not valid
        JSON.
        """
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except ValueError:
                return raw
        return raw

    @staticmethod
    def _encode_tool_arguments(raw: Any) -> str:
        """Serialize arguments back to the JSON string the API expects."""
        # str(dict) would produce Python repr (single quotes), not JSON.
        return raw if isinstance(raw, str) else json.dumps(raw)

    async def generate_response_with_tools(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str,
        tools: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a response with function calling support.

        Args:
            messages: Conversation history.
            system_prompt: System instructions.
            tools: Tool definitions.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        try:
            formatted_messages = [{"role": "system", "content": system_prompt}] + messages
            openai_tools = self._convert_tools_to_openai_format(tools)

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=self.headers,
                    json={
                        "model": self.model,
                        "messages": formatted_messages,
                        "tools": openai_tools,
                        "temperature": self.temperature,
                        "max_tokens": self.max_tokens
                    }
                )
                response.raise_for_status()
                data = response.json()

            choice = data["choices"][0]
            message = choice["message"]

            if "tool_calls" in message and message["tool_calls"]:
                # Decode each JSON-string argument payload into a dict so the
                # tool registry receives the same shape as from Gemini.
                tool_calls = [
                    {
                        "name": tc["function"]["name"],
                        "arguments": self._parse_tool_arguments(tc["function"]["arguments"])
                    }
                    for tc in message["tool_calls"]
                ]
                logger.info(f"OpenRouter requested function calls: {[tc['name'] for tc in tool_calls]}")
                return LLMResponse(
                    content=None,
                    tool_calls=tool_calls,
                    finish_reason=choice.get("finish_reason", "function_call")
                )

            logger.info("OpenRouter generated text response")
            return LLMResponse(
                content=message.get("content"),
                finish_reason=choice.get("finish_reason", "stop")
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"OpenRouter API HTTP error: {e.response.status_code} - {e.response.text}")
            raise
        except Exception as e:
            logger.error(f"OpenRouter API error: {str(e)}")
            raise

    async def generate_response_with_tool_results(
        self,
        messages: List[Dict[str, str]],
        tool_calls: List[Dict[str, Any]],
        tool_results: List[Dict[str, Any]]
    ) -> LLMResponse:
        """
        Generate a final response after tool execution.

        Args:
            messages: Original conversation history.
            tool_calls: Tool calls that were made.
            tool_results: Results from tool execution.

        Returns:
            LLMResponse with final content.
        """
        try:
            messages_with_results = messages.copy()

            # Replay the assistant turn that requested the tools, with
            # synthetic call ids linking calls to their results.
            messages_with_results.append({
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": f"call_{i}",
                        "type": "function",
                        "function": {
                            "name": call["name"],
                            "arguments": self._encode_tool_arguments(call["arguments"])
                        }
                    }
                    for i, call in enumerate(tool_calls)
                ]
            })

            for i, (call, result) in enumerate(zip(tool_calls, tool_results)):
                messages_with_results.append({
                    "role": "tool",
                    "tool_call_id": f"call_{i}",
                    "content": str(result)
                })

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=self.headers,
                    json={
                        "model": self.model,
                        "messages": messages_with_results,
                        "temperature": self.temperature,
                        "max_tokens": self.max_tokens
                    }
                )
                response.raise_for_status()
                data = response.json()

            choice = data["choices"][0]
            logger.info("OpenRouter generated final response after tool execution")
            return LLMResponse(
                content=choice["message"].get("content"),
                finish_reason=choice.get("finish_reason", "stop")
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"OpenRouter API HTTP error: {e.response.status_code} - {e.response.text}")
            raise
        except Exception as e:
            logger.error(f"OpenRouter API error in tool results: {str(e)}")
            raise

    async def generate_simple_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str
    ) -> LLMResponse:
        """
        Generate a simple response without function calling.

        Args:
            messages: Conversation history.
            system_prompt: System instructions.

        Returns:
            LLMResponse with content.
        """
        try:
            formatted_messages = [{"role": "system", "content": system_prompt}] + messages

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=self.headers,
                    json={
                        "model": self.model,
                        "messages": formatted_messages,
                        "temperature": self.temperature,
                        "max_tokens": self.max_tokens
                    }
                )
                response.raise_for_status()
                data = response.json()

            choice = data["choices"][0]
            logger.info("OpenRouter generated simple response")
            return LLMResponse(
                content=choice["message"].get("content"),
                finish_reason=choice.get("finish_reason", "stop")
            )

        except httpx.HTTPStatusError as e:
            logger.error(f"OpenRouter API HTTP error: {e.response.status_code} - {e.response.text}")
            raise
        except Exception as e:
            logger.error(f"OpenRouter API error: {str(e)}")
            raise
import APIRouter, Depends, HTTPException, status +from sqlmodel import Session +from typing import Dict, Any +import logging +from datetime import datetime + +from src.core.database import get_session +from src.core.security import get_current_user +from src.core.config import settings +from src.schemas.chat_request import ChatRequest +from src.schemas.chat_response import ChatResponse +from src.services.conversation_service import ConversationService +from src.agent.agent_config import AgentConfiguration +from src.agent.agent_runner import AgentRunner +from src.mcp import tool_registry +from src.core.exceptions import ( + classify_ai_error, + APIKeyMissingException, + APIKeyInvalidException +) + + +# Configure logging +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api", tags=["chat"]) + + +def generate_conversation_title(first_user_message: str) -> str: + """Generate a conversation title from the first user message. + + Args: + first_user_message: The first message from the user + + Returns: + A title string (max 50 characters) + """ + # Remove leading/trailing whitespace + message = first_user_message.strip() + + # Try to extract the first sentence or first 50 characters + # Split by common sentence endings + for delimiter in ['. ', '! ', '? ', '\n']: + if delimiter in message: + title = message.split(delimiter)[0] + break + else: + # No sentence delimiter found, use first 50 chars + title = message[:50] + + # If title is too short (less than 10 chars), use timestamp-based default + if len(title) < 10: + return f"Chat {datetime.now().strftime('%b %d, %I:%M %p')}" + + # Truncate to 50 characters and add ellipsis if needed + if len(title) > 50: + title = title[:47] + "..." 
+ + return title + + +@router.post("/{user_id}/chat", response_model=ChatResponse) +async def chat( + user_id: int, + request: ChatRequest, + db: Session = Depends(get_session), + current_user: Dict[str, Any] = Depends(get_current_user) +) -> ChatResponse: + """Handle chat messages from users. + + Args: + user_id: ID of the user sending the message + request: ChatRequest containing the user's message + db: Database session + current_user: Authenticated user from JWT token + + Returns: + ChatResponse containing the AI's response + + Raises: + HTTPException 401: If user is not authenticated or user_id doesn't match + HTTPException 404: If conversation_id is provided but not found + HTTPException 500: If AI provider fails to generate response + """ + # Verify user authorization + if current_user["id"] != user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authorized to access this user's chat" + ) + + try: + # Validate request message length + if not request.message or len(request.message.strip()) == 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Message cannot be empty" + ) + + if len(request.message) > 10000: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Message exceeds maximum length of 10,000 characters" + ) + + # Initialize services + conversation_service = ConversationService(db) + + # Initialize agent configuration from settings + try: + agent_config = AgentConfiguration( + provider=settings.LLM_PROVIDER, + fallback_provider=settings.FALLBACK_PROVIDER, + gemini_api_key=settings.GEMINI_API_KEY, + openrouter_api_key=settings.OPENROUTER_API_KEY, + cohere_api_key=settings.COHERE_API_KEY, + temperature=settings.AGENT_TEMPERATURE, + max_tokens=settings.AGENT_MAX_TOKENS, + max_messages=settings.CONVERSATION_MAX_MESSAGES, + max_conversation_tokens=settings.CONVERSATION_MAX_TOKENS + ) + agent_config.validate() + + # Create agent runner with tool registry + agent_runner = 
AgentRunner(agent_config, tool_registry) + except ValueError as e: + logger.error(f"Agent initialization failed: {str(e)}") + # Check if it's an API key issue + error_msg = str(e).lower() + if "api key" in error_msg: + if "not found" in error_msg or "missing" in error_msg: + raise APIKeyMissingException(provider=settings.LLM_PROVIDER) + elif "invalid" in error_msg: + raise APIKeyInvalidException(provider=settings.LLM_PROVIDER) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="AI service is not properly configured. Please contact support." + ) + + # Get or create conversation + is_new_conversation = False + if request.conversation_id: + conversation = conversation_service.get_conversation( + request.conversation_id, + user_id + ) + if not conversation: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {request.conversation_id} not found or you don't have access to it" + ) + else: + # Create new conversation with auto-generated title + try: + # Generate title from first user message + title = generate_conversation_title(request.message) + conversation = conversation_service.create_conversation( + user_id=user_id, + title=title + ) + is_new_conversation = True + logger.info(f"Created new conversation {conversation.id} with title: {title}") + except Exception as e: + logger.error(f"Failed to create conversation: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create conversation. Please try again." + ) + + # Add user message to conversation + try: + user_message = conversation_service.add_message( + conversation_id=conversation.id, + role="user", + content=request.message, + token_count=len(request.message) // 4 # Rough token estimate + ) + except Exception as e: + logger.error(f"Failed to save user message: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to save your message. 
Please try again." + ) + + # Get conversation history and format for agent + history_messages = conversation_service.get_conversation_messages( + conversation_id=conversation.id + ) + + # Format messages for agent with trimming + formatted_messages = conversation_service.format_messages_for_agent( + messages=history_messages, + max_messages=agent_config.max_messages, + max_tokens=agent_config.max_conversation_tokens + ) + + # Generate AI response with tool calling support + system_prompt = request.system_prompt or agent_config.system_prompt + + try: + agent_result = await agent_runner.execute( + messages=formatted_messages, + user_id=user_id, # Inject user context for security + system_prompt=system_prompt + ) + except Exception as e: + # Use classify_ai_error to determine the appropriate exception + logger.error(f"AI service error for user {user_id}: {str(e)}") + provider = agent_result.get("provider") if 'agent_result' in locals() else settings.LLM_PROVIDER + raise classify_ai_error(e, provider=provider) + + # Add AI response to conversation with tool call metadata + try: + # Prepare metadata if tools were used + tool_metadata = None + if agent_result.get("tool_calls"): + # Convert ToolExecutionResult objects to dicts for JSON serialization + tool_results = agent_result.get("tool_results", []) + serializable_results = [] + for result in tool_results: + if hasattr(result, '__dict__'): + # Convert dataclass/object to dict + serializable_results.append({ + "success": result.success, + "data": result.data, + "message": result.message, + "error": result.error + }) + else: + # Already a dict + serializable_results.append(result) + + tool_metadata = { + "tool_calls": agent_result["tool_calls"], + "tool_results": serializable_results, + "provider": agent_result.get("provider") + } + + assistant_message = conversation_service.add_message( + conversation_id=conversation.id, + role="assistant", + content=agent_result["content"], + token_count=len(agent_result["content"]) // 
4 # Rough token estimate + ) + + # Update tool_metadata if tools were used + if tool_metadata: + assistant_message.tool_metadata = tool_metadata + db.add(assistant_message) + db.commit() + except Exception as e: + logger.error(f"Failed to save AI response: {str(e)}") + # Still return the response even if saving fails + # User gets the response but it won't be in history + logger.warning(f"Returning response without saving to database for conversation {conversation.id}") + + # Log tool usage if any + if agent_result.get("tool_calls"): + logger.info(f"Agent used {len(agent_result['tool_calls'])} tools for user {user_id}") + + # Return response + return ChatResponse( + conversation_id=conversation.id, + message=agent_result["content"], + role="assistant", + timestamp=assistant_message.timestamp if 'assistant_message' in locals() else user_message.timestamp, + token_count=len(agent_result["content"]) // 4, + model=agent_result.get("provider") + ) + + except HTTPException: + # Re-raise HTTP exceptions + raise + except Exception as e: + # Catch-all for unexpected errors + logger.exception(f"Unexpected error in chat endpoint: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred. Please try again later." 
+ ) diff --git a/src/api/routes/conversations.py b/src/api/routes/conversations.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7905025a9d27168dc88199d47cb9accd2ad337 --- /dev/null +++ b/src/api/routes/conversations.py @@ -0,0 +1,291 @@ +"""Conversations API endpoints for managing chat conversations.""" +from fastapi import APIRouter, Depends, HTTPException, status, Query +from sqlmodel import Session +from typing import Dict, Any, List +import logging + +from src.core.database import get_session +from src.core.security import get_current_user +from src.services.conversation_service import ConversationService +from src.schemas.conversation import ( + ConversationListResponse, + ConversationSummary, + MessageListResponse, + MessageResponse, + UpdateConversationRequest, + UpdateConversationResponse +) + +# Configure logging +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api", tags=["conversations"]) + + +@router.get("/{user_id}/conversations", response_model=ConversationListResponse) +async def list_conversations( + user_id: int, + limit: int = Query(50, ge=1, le=100, description="Maximum number of conversations to return"), + db: Session = Depends(get_session), + current_user: Dict[str, Any] = Depends(get_current_user) +) -> ConversationListResponse: + """List all conversations for a user. 
+ + Args: + user_id: ID of the user + limit: Maximum number of conversations to return (default: 50, max: 100) + db: Database session + current_user: Authenticated user from JWT token + + Returns: + ConversationListResponse with list of conversations + + Raises: + HTTPException 401: If user is not authenticated or user_id doesn't match + """ + # Verify user authorization + if current_user["id"] != user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authorized to access this user's conversations" + ) + + try: + conversation_service = ConversationService(db) + conversations = conversation_service.get_user_conversations(user_id, limit=limit) + + # Build conversation summaries with message count and preview + summaries: List[ConversationSummary] = [] + for conv in conversations: + # Get messages for this conversation + messages = conversation_service.get_conversation_messages(conv.id) + message_count = len(messages) + + # Get last message preview + last_message_preview = None + if messages: + last_msg = messages[-1] + # Take first 100 characters of the last message + last_message_preview = last_msg.content[:100] + if len(last_msg.content) > 100: + last_message_preview += "..." + + summaries.append(ConversationSummary( + id=conv.id, + title=conv.title, + created_at=conv.created_at, + updated_at=conv.updated_at, + message_count=message_count, + last_message_preview=last_message_preview + )) + + return ConversationListResponse( + conversations=summaries, + total=len(summaries) + ) + + except Exception as e: + logger.exception(f"Failed to list conversations for user {user_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve conversations. Please try again." 
+ ) + + +@router.get("/{user_id}/conversations/{conversation_id}/messages", response_model=MessageListResponse) +async def get_conversation_messages( + user_id: int, + conversation_id: int, + offset: int = Query(0, ge=0, description="Number of messages to skip"), + limit: int = Query(50, ge=1, le=200, description="Maximum number of messages to return"), + db: Session = Depends(get_session), + current_user: Dict[str, Any] = Depends(get_current_user) +) -> MessageListResponse: + """Get message history for a conversation. + + Args: + user_id: ID of the user + conversation_id: ID of the conversation + offset: Number of messages to skip (for pagination) + limit: Maximum number of messages to return (default: 50, max: 200) + db: Database session + current_user: Authenticated user from JWT token + + Returns: + MessageListResponse with list of messages + + Raises: + HTTPException 401: If user is not authenticated or user_id doesn't match + HTTPException 404: If conversation not found or user doesn't have access + """ + # Verify user authorization + if current_user["id"] != user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authorized to access this user's conversations" + ) + + try: + conversation_service = ConversationService(db) + + # Verify conversation exists and belongs to user + conversation = conversation_service.get_conversation(conversation_id, user_id) + if not conversation: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found or you don't have access to it" + ) + + # Get all messages (we'll handle pagination manually) + all_messages = conversation_service.get_conversation_messages(conversation_id) + total = len(all_messages) + + # Apply pagination + paginated_messages = all_messages[offset:offset + limit] + + # Convert to response format + message_responses = [ + MessageResponse( + id=msg.id, + role=msg.role, + content=msg.content, + timestamp=msg.timestamp, + 
token_count=msg.token_count + ) + for msg in paginated_messages + ] + + return MessageListResponse( + conversation_id=conversation_id, + messages=message_responses, + total=total + ) + + except HTTPException: + raise + except Exception as e: + logger.exception(f"Failed to get messages for conversation {conversation_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve messages. Please try again." + ) + + +@router.patch("/{user_id}/conversations/{conversation_id}", response_model=UpdateConversationResponse) +async def update_conversation( + user_id: int, + conversation_id: int, + request: UpdateConversationRequest, + db: Session = Depends(get_session), + current_user: Dict[str, Any] = Depends(get_current_user) +) -> UpdateConversationResponse: + """Update a conversation's title. + + Args: + user_id: ID of the user + conversation_id: ID of the conversation + request: UpdateConversationRequest with new title + db: Database session + current_user: Authenticated user from JWT token + + Returns: + UpdateConversationResponse with updated conversation + + Raises: + HTTPException 401: If user is not authenticated or user_id doesn't match + HTTPException 404: If conversation not found or user doesn't have access + """ + # Verify user authorization + if current_user["id"] != user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authorized to access this user's conversations" + ) + + try: + conversation_service = ConversationService(db) + + # Verify conversation exists and belongs to user + conversation = conversation_service.get_conversation(conversation_id, user_id) + if not conversation: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found or you don't have access to it" + ) + + # Update the title + from datetime import datetime + conversation.title = request.title + conversation.updated_at = datetime.utcnow() 
+ db.add(conversation) + db.commit() + db.refresh(conversation) + + return UpdateConversationResponse( + id=conversation.id, + title=conversation.title, + updated_at=conversation.updated_at + ) + + except HTTPException: + raise + except Exception as e: + logger.exception(f"Failed to update conversation {conversation_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update conversation. Please try again." + ) + + +@router.delete("/{user_id}/conversations/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_conversation( + user_id: int, + conversation_id: int, + db: Session = Depends(get_session), + current_user: Dict[str, Any] = Depends(get_current_user) +) -> None: + """Delete a conversation and all its messages. + + Args: + user_id: ID of the user + conversation_id: ID of the conversation + db: Database session + current_user: Authenticated user from JWT token + + Returns: + None (204 No Content) + + Raises: + HTTPException 401: If user is not authenticated or user_id doesn't match + HTTPException 404: If conversation not found or user doesn't have access + """ + # Verify user authorization + if current_user["id"] != user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authorized to access this user's conversations" + ) + + try: + conversation_service = ConversationService(db) + + # Delete the conversation (service method handles authorization check) + deleted = conversation_service.delete_conversation(conversation_id, user_id) + + if not deleted: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Conversation {conversation_id} not found or you don't have access to it" + ) + + # Return 204 No Content (no response body) + return None + + except HTTPException: + raise + except Exception as e: + logger.exception(f"Failed to delete conversation {conversation_id}: {str(e)}") + raise HTTPException( + 
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to delete conversation. Please try again." + ) diff --git a/src/core/__pycache__/config.cpython-313.pyc b/src/core/__pycache__/config.cpython-313.pyc index 8db1e78267e413ced2b36ab7bd7f72625f576e8e..b96c1bd27e16bb59159e2ebeef31297fc5c6e7e4 100644 Binary files a/src/core/__pycache__/config.cpython-313.pyc and b/src/core/__pycache__/config.cpython-313.pyc differ diff --git a/src/core/__pycache__/exceptions.cpython-313.pyc b/src/core/__pycache__/exceptions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..395c1076a52aa58b41887c80cf42b041893a68e8 Binary files /dev/null and b/src/core/__pycache__/exceptions.cpython-313.pyc differ diff --git a/src/core/__pycache__/security.cpython-313.pyc b/src/core/__pycache__/security.cpython-313.pyc index 847e16eefeaffd23f1694c8b867439b4852e828a..decb5b35dd8dc6e60f9ce450be3dc52781cf1b47 100644 Binary files a/src/core/__pycache__/security.cpython-313.pyc and b/src/core/__pycache__/security.cpython-313.pyc differ diff --git a/src/core/config.py b/src/core/config.py index 3060382acec65342fddb32f81f3263c7f15e2aeb..4447d8483b6eb457d690da7cefe598969ec46789 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -19,6 +19,21 @@ class Settings(BaseSettings): JWT_ALGORITHM: str = "HS256" JWT_EXPIRATION_DAYS: int = 7 + # LLM Provider Configuration + LLM_PROVIDER: str = "gemini" # Primary provider: gemini, openrouter, cohere + FALLBACK_PROVIDER: str | None = None # Optional fallback provider + GEMINI_API_KEY: str | None = None # Required if using Gemini + OPENROUTER_API_KEY: str | None = None # Required if using OpenRouter + COHERE_API_KEY: str | None = None # Required if using Cohere + + # Agent Configuration + AGENT_TEMPERATURE: float = 0.7 # Sampling temperature (0.0 to 1.0) + AGENT_MAX_TOKENS: int = 8192 # Maximum tokens in response + + # Conversation Settings (for free-tier constraints) + CONVERSATION_MAX_MESSAGES: int = 20 # Maximum 
messages to keep in history + CONVERSATION_MAX_TOKENS: int = 8000 # Maximum tokens in conversation history + class Config: env_file = ".env" case_sensitive = True diff --git a/src/core/exceptions.py b/src/core/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..b0ba57655a303a3c880b3592f82a2dc93c112fc8 --- /dev/null +++ b/src/core/exceptions.py @@ -0,0 +1,125 @@ +""" +Custom exception classes for structured error handling. +""" + +from typing import Optional +from fastapi import HTTPException, status +from src.schemas.error import ErrorCode + + +class AIProviderException(HTTPException): + """Base exception for AI provider errors.""" + + def __init__( + self, + error_code: str, + detail: str, + status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR, + provider: Optional[str] = None + ): + super().__init__(status_code=status_code, detail=detail) + self.error_code = error_code + self.source = "AI_PROVIDER" + self.provider = provider + + +class RateLimitExceededException(AIProviderException): + """Exception raised when AI provider rate limit is exceeded.""" + + def __init__(self, provider: Optional[str] = None): + super().__init__( + error_code=ErrorCode.RATE_LIMIT_EXCEEDED, + detail="AI service rate limit exceeded. Please wait a moment and try again.", + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + provider=provider + ) + + +class APIKeyMissingException(AIProviderException): + """Exception raised when API key is not configured.""" + + def __init__(self, provider: Optional[str] = None): + super().__init__( + error_code=ErrorCode.API_KEY_MISSING, + detail="AI service is not configured. 
Please add an API key.", + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + provider=provider + ) + + +class APIKeyInvalidException(AIProviderException): + """Exception raised when API key is invalid or expired.""" + + def __init__(self, provider: Optional[str] = None): + super().__init__( + error_code=ErrorCode.API_KEY_INVALID, + detail="Your API key is invalid or expired. Please check your configuration.", + status_code=status.HTTP_401_UNAUTHORIZED, + provider=provider + ) + + +class ProviderUnavailableException(AIProviderException): + """Exception raised when AI provider is temporarily unavailable.""" + + def __init__(self, provider: Optional[str] = None): + super().__init__( + error_code=ErrorCode.PROVIDER_UNAVAILABLE, + detail="AI service is temporarily unavailable. Please try again later.", + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + provider=provider + ) + + +class ProviderErrorException(AIProviderException): + """Exception raised for generic AI provider errors.""" + + def __init__(self, detail: str, provider: Optional[str] = None): + super().__init__( + error_code=ErrorCode.PROVIDER_ERROR, + detail=detail, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + provider=provider + ) + + +def classify_ai_error(error: Exception, provider: Optional[str] = None) -> AIProviderException: + """ + Classify an AI provider error and return appropriate exception. 
+ + Args: + error: The original exception from the AI provider + provider: Name of the AI provider (gemini, openrouter, cohere) + + Returns: + Appropriate AIProviderException subclass + """ + error_message = str(error).lower() + + # Rate limit errors + if any(keyword in error_message for keyword in ["rate limit", "429", "quota exceeded", "too many requests"]): + return RateLimitExceededException(provider=provider) + + # API key missing errors + if any(keyword in error_message for keyword in ["api key not found", "api key is required", "missing api key"]): + return APIKeyMissingException(provider=provider) + + # API key invalid errors + if any(keyword in error_message for keyword in [ + "invalid api key", "api key invalid", "unauthorized", "401", + "authentication failed", "invalid credentials", "api key expired" + ]): + return APIKeyInvalidException(provider=provider) + + # Provider unavailable errors + if any(keyword in error_message for keyword in [ + "503", "service unavailable", "temporarily unavailable", + "connection refused", "connection timeout", "timeout" + ]): + return ProviderUnavailableException(provider=provider) + + # Generic provider error + return ProviderErrorException( + detail=f"AI service error: {str(error)}", + provider=provider + ) diff --git a/src/core/security.py b/src/core/security.py index 33154ed162bec434afce8bcbc26fe6f54cc03e6b..2ddb5c35a783633b6d52d7cc936ce89c038aa55f 100644 --- a/src/core/security.py +++ b/src/core/security.py @@ -111,9 +111,13 @@ import jwt from datetime import datetime, timedelta from passlib.context import CryptContext -from fastapi import HTTPException, status +from fastapi import HTTPException, status, Depends +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from typing import Dict, Any +from src.core.config import settings pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +security = HTTPBearer() import hashlib MAX_BCRYPT_BYTES = 72 @@ -165,3 +169,50 @@ def 
verify_jwt_token(token: str, secret: str) -> dict: raise HTTPException(status_code=401, detail="Token expired") except jwt.InvalidTokenError: raise HTTPException(status_code=401, detail="Invalid token") + + +def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(security) +) -> Dict[str, Any]: + """ + FastAPI dependency to extract and validate JWT token from Authorization header. + + Args: + credentials: HTTP Bearer token credentials from request header + + Returns: + Dictionary containing user information from token payload: + - id: User ID (parsed from 'sub' claim) + - email: User email + - iat: Token issued at timestamp + - exp: Token expiration timestamp + + Raises: + HTTPException 401: If token is missing, invalid, or expired + """ + token = credentials.credentials + + try: + payload = verify_jwt_token(token, settings.BETTER_AUTH_SECRET) + + # Extract user ID from 'sub' claim and convert to integer + user_id = int(payload.get("sub")) + + return { + "id": user_id, + "email": payload.get("email"), + "iat": payload.get("iat"), + "exp": payload.get("exp") + } + except ValueError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid user ID in token", + headers={"WWW-Authenticate": "Bearer"} + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Authentication failed: {str(e)}", + headers={"WWW-Authenticate": "Bearer"} + ) diff --git a/src/main.py b/src/main.py index 600e85df145bbf2863a4b52f93fa39ee96826a7b..a6d222dbbab58ff9a4acd5ef287490c0bef6fdc4 100644 --- a/src/main.py +++ b/src/main.py @@ -1,13 +1,76 @@ -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +import logging from .core.config import settings -from .api.routes import tasks, auth +from .api.routes import tasks, auth, chat, conversations +from .mcp import register_all_tools +from 
.core.exceptions import AIProviderException +from .schemas.error import ErrorResponse + +# Configure logging +logger = logging.getLogger(__name__) app = FastAPI( title=settings.APP_NAME, debug=settings.DEBUG ) + +# Global exception handler for AIProviderException +@app.exception_handler(AIProviderException) +async def ai_provider_exception_handler(request: Request, exc: AIProviderException): + """Convert AIProviderException to structured ErrorResponse.""" + error_response = ErrorResponse( + error_code=exc.error_code, + detail=exc.detail, + source=exc.source, + provider=exc.provider + ) + logger.error( + f"AI Provider Error: {exc.error_code} - {exc.detail} " + f"(Provider: {exc.provider}, Status: {exc.status_code})" + ) + return JSONResponse( + status_code=exc.status_code, + content=error_response.model_dump() + ) + + +# Global exception handler for generic HTTPException +@app.exception_handler(Exception) +async def generic_exception_handler(request: Request, exc: Exception): + """Catch-all exception handler for unexpected errors.""" + # Log the full exception for debugging + logger.exception(f"Unhandled exception: {str(exc)}") + + # Return structured error response + error_response = ErrorResponse( + error_code="INTERNAL_ERROR", + detail="An unexpected error occurred. 
Please try again later.", + source="INTERNAL" + ) + return JSONResponse( + status_code=500, + content=error_response.model_dump() + ) + + +@app.on_event("startup") +async def startup_event(): + """Initialize application on startup.""" + logger.info("Starting application initialization...") + + # Register all MCP tools with the tool registry + try: + register_all_tools() + logger.info("MCP tools registered successfully") + except Exception as e: + logger.error(f"Failed to register MCP tools: {str(e)}") + raise + + logger.info("Application initialization complete") + # Configure CORS app.add_middleware( CORSMiddleware, @@ -20,6 +83,8 @@ app.add_middleware( # Register routes app.include_router(auth.router) app.include_router(tasks.router) +app.include_router(chat.router) +app.include_router(conversations.router) @app.get("/") diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01028f6d68eedd6e23cc6d42cde8d5830ad9d563 --- /dev/null +++ b/src/mcp/__init__.py @@ -0,0 +1,130 @@ +""" +MCP Tools Registration + +Registers all MCP tools with the global tool registry. +Each tool is registered with its contract definition (name, description, parameters). +""" + +import json +import logging +from pathlib import Path + +from .tool_registry import tool_registry +from .tools.add_task import add_task +from .tools.list_tasks import list_tasks +from .tools.complete_task import complete_task +from .tools.delete_task import delete_task +from .tools.update_task import update_task + +logger = logging.getLogger(__name__) + + +def load_tool_contract(tool_name: str) -> dict: + """ + Load tool contract definition from JSON file. 
+ + Args: + tool_name: Name of the tool (e.g., "add_task") + + Returns: + Tool contract dictionary + + Raises: + FileNotFoundError: If contract file not found + """ + # Get the project root directory + current_file = Path(__file__) + project_root = current_file.parent.parent.parent # backend/src/mcp -> backend + contract_path = project_root.parent / "specs" / "001-openai-agent-mcp-tools" / "contracts" / f"{tool_name}.json" + + if not contract_path.exists(): + raise FileNotFoundError(f"Contract file not found: {contract_path}") + + with open(contract_path, "r") as f: + return json.load(f) + + +def register_all_tools(): + """ + Register all MCP tools with the global tool registry. + + This function should be called during application startup to ensure + all tools are available to the agent. + """ + logger.info("Registering MCP tools...") + + # Register add_task tool + try: + add_task_contract = load_tool_contract("add_task") + tool_registry.register_tool( + name=add_task_contract["name"], + description=add_task_contract["description"], + parameters=add_task_contract["parameters"], + handler=add_task + ) + logger.info("Registered tool: add_task") + except Exception as e: + logger.error(f"Failed to register add_task tool: {str(e)}") + raise + + # Register list_tasks tool + try: + list_tasks_contract = load_tool_contract("list_tasks") + tool_registry.register_tool( + name=list_tasks_contract["name"], + description=list_tasks_contract["description"], + parameters=list_tasks_contract["parameters"], + handler=list_tasks + ) + logger.info("Registered tool: list_tasks") + except Exception as e: + logger.error(f"Failed to register list_tasks tool: {str(e)}") + raise + + # Register complete_task tool + try: + complete_task_contract = load_tool_contract("complete_task") + tool_registry.register_tool( + name=complete_task_contract["name"], + description=complete_task_contract["description"], + parameters=complete_task_contract["parameters"], + handler=complete_task + ) + 
logger.info("Registered tool: complete_task") + except Exception as e: + logger.error(f"Failed to register complete_task tool: {str(e)}") + raise + + # Register delete_task tool + try: + delete_task_contract = load_tool_contract("delete_task") + tool_registry.register_tool( + name=delete_task_contract["name"], + description=delete_task_contract["description"], + parameters=delete_task_contract["parameters"], + handler=delete_task + ) + logger.info("Registered tool: delete_task") + except Exception as e: + logger.error(f"Failed to register delete_task tool: {str(e)}") + raise + + # Register update_task tool + try: + update_task_contract = load_tool_contract("update_task") + tool_registry.register_tool( + name=update_task_contract["name"], + description=update_task_contract["description"], + parameters=update_task_contract["parameters"], + handler=update_task + ) + logger.info("Registered tool: update_task") + except Exception as e: + logger.error(f"Failed to register update_task tool: {str(e)}") + raise + + logger.info(f"Successfully registered {len(tool_registry.list_tools())} MCP tools") + + +# Export the global registry instance for use in other modules +__all__ = ["tool_registry", "register_all_tools"] diff --git a/src/mcp/__pycache__/__init__.cpython-313.pyc b/src/mcp/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d229b9b80891d267766492d801db9ea625be74e Binary files /dev/null and b/src/mcp/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/mcp/__pycache__/tool_registry.cpython-313.pyc b/src/mcp/__pycache__/tool_registry.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1129acdfb839f487bc97f19f6419011f210ed990 Binary files /dev/null and b/src/mcp/__pycache__/tool_registry.cpython-313.pyc differ diff --git a/src/mcp/tool_registry.py b/src/mcp/tool_registry.py new file mode 100644 index 
0000000000000000000000000000000000000000..fa0fc61e12738f99a3338b91d82083e7ad515929 --- /dev/null +++ b/src/mcp/tool_registry.py @@ -0,0 +1,140 @@ +""" +MCP Tool Registry + +Manages registration and execution of MCP tools with user context injection. +Security: user_id is injected by the backend, never trusted from LLM output. +""" + +from typing import Dict, List, Any, Callable, Optional +from dataclasses import dataclass +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolDefinition: + """Definition of an MCP tool for LLM function calling.""" + name: str + description: str + parameters: Dict[str, Any] + + +@dataclass +class ToolExecutionResult: + """Result of executing an MCP tool.""" + success: bool + data: Optional[Dict[str, Any]] = None + message: Optional[str] = None + error: Optional[str] = None + + +class MCPToolRegistry: + """ + Registry for MCP tools with user context injection. + + This class manages tool registration and execution, ensuring that + user_id is always injected by the backend for security. + """ + + def __init__(self): + self._tools: Dict[str, Callable] = {} + self._tool_definitions: Dict[str, ToolDefinition] = {} + + def register_tool( + self, + name: str, + description: str, + parameters: Dict[str, Any], + handler: Callable + ) -> None: + """ + Register an MCP tool with its handler function. + + Args: + name: Tool name (e.g., "add_task") + description: Tool description for LLM + parameters: JSON schema for tool parameters + handler: Async function that executes the tool + """ + self._tools[name] = handler + self._tool_definitions[name] = ToolDefinition( + name=name, + description=description, + parameters=parameters + ) + logger.info(f"Registered MCP tool: {name}") + + def get_tool_definitions(self) -> List[Dict[str, Any]]: + """ + Get tool definitions in format suitable for LLM function calling. 
+ + Returns: + List of tool definitions with name, description, and parameters + """ + return [ + { + "name": tool_def.name, + "description": tool_def.description, + "parameters": tool_def.parameters + } + for tool_def in self._tool_definitions.values() + ] + + async def execute_tool( + self, + tool_name: str, + arguments: Dict[str, Any], + user_id: int + ) -> ToolExecutionResult: + """ + Execute an MCP tool with user context injection. + + SECURITY: user_id is injected by the backend, never from LLM output. + + Args: + tool_name: Name of the tool to execute + arguments: Tool arguments from LLM + user_id: User ID (injected by backend, not from LLM) + + Returns: + ToolExecutionResult with success status and data/error + """ + if tool_name not in self._tools: + logger.error(f"Tool not found: {tool_name}") + return ToolExecutionResult( + success=False, + error=f"Tool '{tool_name}' not found" + ) + + try: + # Inject user_id into arguments for security + arguments_with_context = {**arguments, "user_id": user_id} + + logger.info(f"Executing tool: {tool_name} for user: {user_id}") + + # Execute the tool handler + handler = self._tools[tool_name] + result = await handler(**arguments_with_context) + + logger.info(f"Tool execution successful: {tool_name}") + return result + + except Exception as e: + logger.error(f"Tool execution failed: {tool_name} - {str(e)}") + return ToolExecutionResult( + success=False, + error=f"Tool execution failed: {str(e)}" + ) + + def list_tools(self) -> List[str]: + """Get list of registered tool names.""" + return list(self._tools.keys()) + + def has_tool(self, tool_name: str) -> bool: + """Check if a tool is registered.""" + return tool_name in self._tools + + +# Global registry instance +tool_registry = MCPToolRegistry() diff --git a/src/mcp/tools/__init__.py b/src/mcp/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/src/mcp/tools/__pycache__/__init__.cpython-313.pyc b/src/mcp/tools/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96ec5e14ee3d3de5b27a912ba62b59965059de76 Binary files /dev/null and b/src/mcp/tools/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/mcp/tools/__pycache__/add_task.cpython-313.pyc b/src/mcp/tools/__pycache__/add_task.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ea7cea13711e1888f3b433332573e50fb50dedb Binary files /dev/null and b/src/mcp/tools/__pycache__/add_task.cpython-313.pyc differ diff --git a/src/mcp/tools/__pycache__/complete_task.cpython-313.pyc b/src/mcp/tools/__pycache__/complete_task.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d73514b222921c303cdea3cc00a9df243699115a Binary files /dev/null and b/src/mcp/tools/__pycache__/complete_task.cpython-313.pyc differ diff --git a/src/mcp/tools/__pycache__/delete_task.cpython-313.pyc b/src/mcp/tools/__pycache__/delete_task.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f466042c4a0fe6f51ff46aa1677bf6a70b7d7dba Binary files /dev/null and b/src/mcp/tools/__pycache__/delete_task.cpython-313.pyc differ diff --git a/src/mcp/tools/__pycache__/list_tasks.cpython-313.pyc b/src/mcp/tools/__pycache__/list_tasks.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a14cf62dc6aa371258dce35c2cdde9e597bb8dab Binary files /dev/null and b/src/mcp/tools/__pycache__/list_tasks.cpython-313.pyc differ diff --git a/src/mcp/tools/__pycache__/update_task.cpython-313.pyc b/src/mcp/tools/__pycache__/update_task.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9afa0e7d80175c14406974658d82889c82113f7 Binary files /dev/null and b/src/mcp/tools/__pycache__/update_task.cpython-313.pyc differ diff --git a/src/mcp/tools/add_task.py b/src/mcp/tools/add_task.py new file mode 100644 
index 0000000000000000000000000000000000000000..b21a929e9979700aa2ca368b3c44ec11398dcdfb --- /dev/null +++ b/src/mcp/tools/add_task.py @@ -0,0 +1,131 @@ +""" +Add Task MCP Tool + +MCP tool for creating new tasks via natural language. +Implements user context injection for security. +""" + +import logging +from typing import Optional +from datetime import datetime +from sqlmodel import Session + +from ...models.task import Task +from ...core.database import get_session +from ..tool_registry import ToolExecutionResult + +logger = logging.getLogger(__name__) + + +async def add_task( + title: str, + user_id: int, # Injected by backend, never from LLM + description: Optional[str] = None, + due_date: Optional[str] = None, + priority: Optional[str] = "medium" +) -> ToolExecutionResult: + """ + Create a new task for the user. + + SECURITY: user_id is injected by the backend via MCPToolRegistry. + The LLM cannot specify or modify the user_id. + + Args: + title: Task title (max 200 characters) + user_id: User ID (injected by backend for security) + description: Optional task description (max 1000 characters) + due_date: Optional due date in ISO 8601 format (YYYY-MM-DD) + priority: Task priority (low, medium, high) - default: medium + + Returns: + ToolExecutionResult with success status and task data + """ + try: + # Validate title + if not title or not title.strip(): + logger.warning("add_task called with empty title") + return ToolExecutionResult( + success=False, + error="Task title cannot be empty" + ) + + if len(title) > 200: + logger.warning(f"add_task called with title exceeding 200 characters: {len(title)}") + return ToolExecutionResult( + success=False, + error="Task title cannot exceed 200 characters" + ) + + # Validate description length + if description and len(description) > 1000: + logger.warning(f"add_task called with description exceeding 1000 characters: {len(description)}") + return ToolExecutionResult( + success=False, + error="Task description cannot 
exceed 1000 characters" + ) + + # Validate priority + valid_priorities = ["low", "medium", "high"] + if priority and priority.lower() not in valid_priorities: + logger.warning(f"add_task called with invalid priority: {priority}") + return ToolExecutionResult( + success=False, + error=f"Priority must be one of: {', '.join(valid_priorities)}" + ) + + # Validate and parse due_date if provided + parsed_due_date = None + if due_date: + try: + parsed_due_date = datetime.fromisoformat(due_date).date() + except ValueError: + logger.warning(f"add_task called with invalid due_date format: {due_date}") + return ToolExecutionResult( + success=False, + error="Due date must be in ISO 8601 format (YYYY-MM-DD)" + ) + + # Create task in database + db: Session = next(get_session()) + try: + task = Task( + user_id=user_id, + title=title.strip(), + description=description.strip() if description else None, + due_date=parsed_due_date, + priority=priority.lower() if priority else "medium", + completed=False, + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() + ) + + db.add(task) + db.commit() + db.refresh(task) + + logger.info(f"Task created successfully: id={task.id}, user_id={user_id}, title={title}") + + return ToolExecutionResult( + success=True, + data={ + "id": task.id, + "title": task.title, + "description": task.description, + "due_date": task.due_date.isoformat() if task.due_date else None, + "priority": task.priority, + "completed": task.completed, + "created_at": task.created_at.isoformat(), + "updated_at": task.updated_at.isoformat() + }, + message=f"Task '{title}' created successfully" + ) + + finally: + db.close() + + except Exception as e: + logger.error(f"Error creating task: {str(e)}") + return ToolExecutionResult( + success=False, + error=f"Failed to create task: {str(e)}" + ) diff --git a/src/mcp/tools/complete_task.py b/src/mcp/tools/complete_task.py new file mode 100644 index 
0000000000000000000000000000000000000000..5b84c44c8c0e1c1731925b0fbe52475de2db27f9 --- /dev/null +++ b/src/mcp/tools/complete_task.py @@ -0,0 +1,113 @@ +""" +Complete Task MCP Tool + +MCP tool for marking tasks as completed via natural language. +Supports task identification by ID or title. +Implements user context injection for security. +""" + +import logging +from typing import Union +from datetime import datetime +from sqlmodel import Session, select + +from ...models.task import Task +from ...core.database import get_session +from ..tool_registry import ToolExecutionResult + +logger = logging.getLogger(__name__) + + +async def complete_task( + task_identifier: Union[int, str], + user_id: int # Injected by backend, never from LLM +) -> ToolExecutionResult: + """ + Mark a task as completed. + + SECURITY: user_id is injected by the backend via MCPToolRegistry. + The LLM cannot specify or modify the user_id. + + Args: + task_identifier: Task ID (integer) or task title (string) + user_id: User ID (injected by backend for security) + + Returns: + ToolExecutionResult with success status and updated task data + """ + try: + # Query task from database + db: Session = next(get_session()) + try: + # Build query based on identifier type + if isinstance(task_identifier, int): + # Search by ID + statement = select(Task).where( + Task.id == task_identifier, + Task.user_id == user_id + ) + identifier_type = "ID" + else: + # Search by title (exact match) + statement = select(Task).where( + Task.title == task_identifier, + Task.user_id == user_id + ) + identifier_type = "title" + + task = db.exec(statement).first() + + # Check if task exists + if not task: + logger.warning(f"Task not found: {identifier_type}={task_identifier}, user_id={user_id}") + return ToolExecutionResult( + success=False, + error=f"Task not found with {identifier_type}: {task_identifier}" + ) + + # Check if already completed + if task.completed: + logger.info(f"Task already completed: id={task.id}, 
user_id={user_id}") + return ToolExecutionResult( + success=True, + data={ + "id": task.id, + "title": task.title, + "description": task.description, + "completed": task.completed, + "updated_at": task.updated_at.isoformat() + }, + message=f"Task '{task.title}' was already marked as completed." + ) + + # Mark task as completed + task.completed = True + task.updated_at = datetime.utcnow() + + db.add(task) + db.commit() + db.refresh(task) + + logger.info(f"Task completed successfully: id={task.id}, user_id={user_id}, title={task.title}") + + return ToolExecutionResult( + success=True, + data={ + "id": task.id, + "title": task.title, + "description": task.description, + "completed": task.completed, + "updated_at": task.updated_at.isoformat() + }, + message=f"Task '{task.title}' marked as completed!" + ) + + finally: + db.close() + + except Exception as e: + logger.error(f"Error completing task: {str(e)}") + return ToolExecutionResult( + success=False, + error=f"Failed to complete task: {str(e)}" + ) diff --git a/src/mcp/tools/delete_task.py b/src/mcp/tools/delete_task.py new file mode 100644 index 0000000000000000000000000000000000000000..248fcb30006f1b6798c81229bd453f06ab300c2c --- /dev/null +++ b/src/mcp/tools/delete_task.py @@ -0,0 +1,89 @@ +""" +Delete Task MCP Tool + +MCP tool for deleting tasks via natural language. +Supports task identification by ID or title. +Implements user context injection for security. +""" + +import logging +from typing import Union +from sqlmodel import Session, select + +from ...models.task import Task +from ...core.database import get_session +from ..tool_registry import ToolExecutionResult + +logger = logging.getLogger(__name__) + + +async def delete_task( + task_identifier: Union[int, str], + user_id: int # Injected by backend, never from LLM +) -> ToolExecutionResult: + """ + Delete a task permanently. + + SECURITY: user_id is injected by the backend via MCPToolRegistry. + The LLM cannot specify or modify the user_id. 
+ + Args: + task_identifier: Task ID (integer) or task title (string) + user_id: User ID (injected by backend for security) + + Returns: + ToolExecutionResult with success status and confirmation message + """ + try: + # Query task from database + db: Session = next(get_session()) + try: + # Build query based on identifier type + if isinstance(task_identifier, int): + # Search by ID + statement = select(Task).where( + Task.id == task_identifier, + Task.user_id == user_id + ) + identifier_type = "ID" + else: + # Search by title (exact match) + statement = select(Task).where( + Task.title == task_identifier, + Task.user_id == user_id + ) + identifier_type = "title" + + task = db.exec(statement).first() + + # Check if task exists + if not task: + logger.warning(f"Task not found for deletion: {identifier_type}={task_identifier}, user_id={user_id}") + return ToolExecutionResult( + success=False, + error=f"Task not found with {identifier_type}: {task_identifier}" + ) + + # Store task title for confirmation message + task_title = task.title + + # Delete task + db.delete(task) + db.commit() + + logger.info(f"Task deleted successfully: id={task.id}, user_id={user_id}, title={task_title}") + + return ToolExecutionResult( + success=True, + message=f"Task '{task_title}' has been deleted successfully." + ) + + finally: + db.close() + + except Exception as e: + logger.error(f"Error deleting task: {str(e)}") + return ToolExecutionResult( + success=False, + error=f"Failed to delete task: {str(e)}" + ) diff --git a/src/mcp/tools/list_tasks.py b/src/mcp/tools/list_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba76724350f3740378dc6414251602bbd73e538 --- /dev/null +++ b/src/mcp/tools/list_tasks.py @@ -0,0 +1,118 @@ +""" +List Tasks MCP Tool + +MCP tool for listing tasks via natural language with filtering support. +Implements user context injection for security. 
+""" + +import logging +from typing import Optional +from sqlmodel import Session, select + +from ...models.task import Task +from ...core.database import get_session +from ..tool_registry import ToolExecutionResult + +logger = logging.getLogger(__name__) + + +async def list_tasks( + user_id: int, # Injected by backend, never from LLM + filter: Optional[str] = "all" +) -> ToolExecutionResult: + """ + List all tasks for the authenticated user with optional filtering. + + SECURITY: user_id is injected by the backend via MCPToolRegistry. + The LLM cannot specify or modify the user_id. + + Args: + user_id: User ID (injected by backend for security) + filter: Filter by completion status (all, completed, incomplete) - default: all + + Returns: + ToolExecutionResult with success status and tasks data + """ + try: + # Validate filter parameter + valid_filters = ["all", "completed", "incomplete"] + if filter and filter.lower() not in valid_filters: + logger.warning(f"list_tasks called with invalid filter: {filter}") + return ToolExecutionResult( + success=False, + error=f"Filter must be one of: {', '.join(valid_filters)}" + ) + + filter_value = filter.lower() if filter else "all" + + # Query tasks from database + db: Session = next(get_session()) + try: + # Build query based on filter + statement = select(Task).where(Task.user_id == user_id) + + if filter_value == "completed": + statement = statement.where(Task.completed == True) + elif filter_value == "incomplete": + statement = statement.where(Task.completed == False) + # "all" filter doesn't add any additional conditions + + # Order by creation date (newest first) + statement = statement.order_by(Task.created_at.desc()) + + # Execute query + tasks = db.exec(statement).all() + + # Format tasks for response + tasks_data = [ + { + "id": task.id, + "title": task.title, + "description": task.description, + "due_date": task.due_date.isoformat() if task.due_date else None, + "priority": task.priority, + "completed": 
task.completed, + "created_at": task.created_at.isoformat(), + "updated_at": task.updated_at.isoformat() + } + for task in tasks + ] + + # Generate user-friendly message + count = len(tasks_data) + if count == 0: + if filter_value == "completed": + message = "You have no completed tasks." + elif filter_value == "incomplete": + message = "You have no incomplete tasks." + else: + message = "You have no tasks yet. Create one to get started!" + else: + if filter_value == "completed": + message = f"You have {count} completed task{'s' if count != 1 else ''}." + elif filter_value == "incomplete": + message = f"You have {count} incomplete task{'s' if count != 1 else ''}." + else: + message = f"You have {count} task{'s' if count != 1 else ''} in total." + + logger.info(f"Listed {count} tasks for user_id={user_id} with filter={filter_value}") + + return ToolExecutionResult( + success=True, + data={ + "tasks": tasks_data, + "count": count, + "filter": filter_value + }, + message=message + ) + + finally: + db.close() + + except Exception as e: + logger.error(f"Error listing tasks: {str(e)}") + return ToolExecutionResult( + success=False, + error=f"Failed to list tasks: {str(e)}" + ) diff --git a/src/mcp/tools/update_task.py b/src/mcp/tools/update_task.py new file mode 100644 index 0000000000000000000000000000000000000000..527aff1ff0b3e708d93abce7c7d32ec2ea6584d3 --- /dev/null +++ b/src/mcp/tools/update_task.py @@ -0,0 +1,182 @@ +""" +Update Task MCP Tool + +MCP tool for updating task properties via natural language. +Supports task identification by ID or title and updating multiple fields. +Implements user context injection for security. 
+""" + +import logging +from typing import Union, Optional, Dict, Any +from datetime import datetime +from sqlmodel import Session, select + +from ...models.task import Task +from ...core.database import get_session +from ..tool_registry import ToolExecutionResult + +logger = logging.getLogger(__name__) + + +async def update_task( + task_identifier: Union[int, str], + updates: Dict[str, Any], + user_id: int # Injected by backend, never from LLM +) -> ToolExecutionResult: + """ + Update an existing task's properties. + + SECURITY: user_id is injected by the backend via MCPToolRegistry. + The LLM cannot specify or modify the user_id. + + Args: + task_identifier: Task ID (integer) or task title (string) + updates: Dictionary of fields to update (title, description, due_date, priority, completed) + user_id: User ID (injected by backend for security) + + Returns: + ToolExecutionResult with success status and updated task data + """ + try: + # Validate updates dictionary + if not updates or len(updates) == 0: + logger.warning("update_task called with empty updates dictionary") + return ToolExecutionResult( + success=False, + error="No updates provided. Please specify at least one field to update." + ) + + # Validate allowed fields + allowed_fields = ["title", "description", "due_date", "priority", "completed"] + invalid_fields = [field for field in updates.keys() if field not in allowed_fields] + if invalid_fields: + logger.warning(f"update_task called with invalid fields: {invalid_fields}") + return ToolExecutionResult( + success=False, + error=f"Invalid fields: {', '.join(invalid_fields)}. 
Allowed fields: {', '.join(allowed_fields)}" + ) + + # Query task from database + db: Session = next(get_session()) + try: + # Build query based on identifier type + if isinstance(task_identifier, int): + # Search by ID + statement = select(Task).where( + Task.id == task_identifier, + Task.user_id == user_id + ) + identifier_type = "ID" + else: + # Search by title (exact match) + statement = select(Task).where( + Task.title == task_identifier, + Task.user_id == user_id + ) + identifier_type = "title" + + task = db.exec(statement).first() + + # Check if task exists + if not task: + logger.warning(f"Task not found for update: {identifier_type}={task_identifier}, user_id={user_id}") + return ToolExecutionResult( + success=False, + error=f"Task not found with {identifier_type}: {task_identifier}" + ) + + # Apply updates with validation + updated_fields = [] + + # Update title + if "title" in updates: + new_title = updates["title"] + if not new_title or not new_title.strip(): + return ToolExecutionResult( + success=False, + error="Task title cannot be empty" + ) + if len(new_title) > 200: + return ToolExecutionResult( + success=False, + error="Task title cannot exceed 200 characters" + ) + task.title = new_title.strip() + updated_fields.append("title") + + # Update description + if "description" in updates: + new_description = updates["description"] + if new_description and len(new_description) > 1000: + return ToolExecutionResult( + success=False, + error="Task description cannot exceed 1000 characters" + ) + task.description = new_description.strip() if new_description else None + updated_fields.append("description") + + # Update due_date + if "due_date" in updates: + new_due_date = updates["due_date"] + if new_due_date: + try: + task.due_date = datetime.fromisoformat(new_due_date).date() + except ValueError: + return ToolExecutionResult( + success=False, + error="Due date must be in ISO 8601 format (YYYY-MM-DD)" + ) + else: + task.due_date = None + 
updated_fields.append("due_date") + + # Update priority + if "priority" in updates: + new_priority = updates["priority"] + valid_priorities = ["low", "medium", "high"] + if new_priority and new_priority.lower() not in valid_priorities: + return ToolExecutionResult( + success=False, + error=f"Priority must be one of: {', '.join(valid_priorities)}" + ) + task.priority = new_priority.lower() if new_priority else "medium" + updated_fields.append("priority") + + # Update completed status + if "completed" in updates: + task.completed = bool(updates["completed"]) + updated_fields.append("completed") + + # Update timestamp + task.updated_at = datetime.utcnow() + + # Save changes + db.add(task) + db.commit() + db.refresh(task) + + logger.info(f"Task updated successfully: id={task.id}, user_id={user_id}, fields={updated_fields}") + + return ToolExecutionResult( + success=True, + data={ + "id": task.id, + "title": task.title, + "description": task.description, + "due_date": task.due_date.isoformat() if task.due_date else None, + "priority": task.priority, + "completed": task.completed, + "updated_at": task.updated_at.isoformat() + }, + message=f"Task '{task.title}' updated successfully. 
Updated fields: {', '.join(updated_fields)}" + ) + + finally: + db.close() + + except Exception as e: + logger.error(f"Error updating task: {str(e)}") + return ToolExecutionResult( + success=False, + error=f"Failed to update task: {str(e)}" + ) diff --git a/src/models/__pycache__/conversation.cpython-313.pyc b/src/models/__pycache__/conversation.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27cfe470439409e84294318e6f16822c739e2e98 Binary files /dev/null and b/src/models/__pycache__/conversation.cpython-313.pyc differ diff --git a/src/models/__pycache__/message.cpython-313.pyc b/src/models/__pycache__/message.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3f33be55fbb9cc07277a06e6613a3182cc90221 Binary files /dev/null and b/src/models/__pycache__/message.cpython-313.pyc differ diff --git a/src/models/__pycache__/task.cpython-313.pyc b/src/models/__pycache__/task.cpython-313.pyc index 154fab77373b705ac2780c24bc3247995de84939..c26bc275e7fe0fd4dd790d920f87eb67110bd8fb 100644 Binary files a/src/models/__pycache__/task.cpython-313.pyc and b/src/models/__pycache__/task.cpython-313.pyc differ diff --git a/src/models/conversation.py b/src/models/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..f39a41004e92278c64f04605d097a66d2d11c0f8 --- /dev/null +++ b/src/models/conversation.py @@ -0,0 +1,23 @@ +"""Conversation model for AI chatbot.""" +from datetime import datetime +from typing import Optional, List, TYPE_CHECKING +from sqlmodel import Field, SQLModel, Relationship + +if TYPE_CHECKING: + from .message import Message + from .user import User + + +class Conversation(SQLModel, table=True): + """Conversation model representing a chat session between user and AI.""" + + __tablename__ = "conversation" + + id: Optional[int] = Field(default=None, primary_key=True) + user_id: int = Field(foreign_key="users.id", nullable=False, index=True) + title: Optional[str] = 
Field(default=None, max_length=255) + created_at: datetime = Field(default_factory=datetime.utcnow, nullable=False, index=True) + updated_at: datetime = Field(default_factory=datetime.utcnow, nullable=False) + + # Relationships + messages: List["Message"] = Relationship(back_populates="conversation", cascade_delete=True) diff --git a/src/models/message.py b/src/models/message.py new file mode 100644 index 0000000000000000000000000000000000000000..52132870a186bf347db161b1c71ec651d68cde62 --- /dev/null +++ b/src/models/message.py @@ -0,0 +1,25 @@ +"""Message model for AI chatbot conversations.""" +from datetime import datetime +from typing import Optional, Dict, Any, TYPE_CHECKING +from sqlmodel import Field, SQLModel, Relationship, Column +from sqlalchemy import JSON + +if TYPE_CHECKING: + from .conversation import Conversation + + +class Message(SQLModel, table=True): + """Message model representing a single message in a conversation.""" + + __tablename__ = "message" + + id: Optional[int] = Field(default=None, primary_key=True) + conversation_id: int = Field(foreign_key="conversation.id", nullable=False, index=True) + role: str = Field(max_length=50, nullable=False) # "user" or "assistant" + content: str = Field(nullable=False) + timestamp: datetime = Field(default_factory=datetime.utcnow, nullable=False, index=True) + token_count: Optional[int] = Field(default=None) + tool_metadata: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) + + # Relationships + conversation: "Conversation" = Relationship(back_populates="messages") diff --git a/src/models/task.py b/src/models/task.py index c32895b22f54a5ca1d21a533675a33dd2cd47001..11eec8e0f91dd84ee880ad1134db3954d710b49e 100644 --- a/src/models/task.py +++ b/src/models/task.py @@ -1,5 +1,5 @@ from sqlmodel import SQLModel, Field -from datetime import datetime +from datetime import datetime, date from typing import Optional @@ -13,5 +13,7 @@ class Task(SQLModel, table=True): title: str = 
Field(max_length=200, nullable=False) description: Optional[str] = Field(default=None, max_length=1000) completed: bool = Field(default=False, nullable=False, index=True) + due_date: Optional[date] = Field(default=None) + priority: str = Field(default="medium", max_length=20) created_at: datetime = Field(default_factory=datetime.utcnow, nullable=False, index=True) updated_at: datetime = Field(default_factory=datetime.utcnow, nullable=False) diff --git a/src/schemas/__pycache__/auth.cpython-313.pyc b/src/schemas/__pycache__/auth.cpython-313.pyc index 4e3aa644d37ce86416f7d89289ac2be0e1f5db11..672d84d4b23969fd0b99cc1518b63ba191b222bf 100644 Binary files a/src/schemas/__pycache__/auth.cpython-313.pyc and b/src/schemas/__pycache__/auth.cpython-313.pyc differ diff --git a/src/schemas/__pycache__/chat_request.cpython-313.pyc b/src/schemas/__pycache__/chat_request.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9627b24e173364673b2bb9f3524d777110c8b6cb Binary files /dev/null and b/src/schemas/__pycache__/chat_request.cpython-313.pyc differ diff --git a/src/schemas/__pycache__/chat_response.cpython-313.pyc b/src/schemas/__pycache__/chat_response.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df13b4ace890b9af6b0c8c034abffbdc5ac069f2 Binary files /dev/null and b/src/schemas/__pycache__/chat_response.cpython-313.pyc differ diff --git a/src/schemas/__pycache__/conversation.cpython-313.pyc b/src/schemas/__pycache__/conversation.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b017dba3eac266023dc4a6f9db25391ce0bf584 Binary files /dev/null and b/src/schemas/__pycache__/conversation.cpython-313.pyc differ diff --git a/src/schemas/__pycache__/error.cpython-313.pyc b/src/schemas/__pycache__/error.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..307723c59a307e6d475f22abc8fe2126d74c1c0b Binary files /dev/null and 
"""Pydantic schema for chat request."""
from pydantic import BaseModel, Field
from typing import Optional


class ChatRequest(BaseModel):
    """Request schema for chat endpoint.

    Represents a user's message to the AI chatbot.
    """

    # Required; bounds reject empty messages and oversized payloads.
    message: str = Field(
        ...,
        min_length=1,
        max_length=10000,
        description="User's message to the AI assistant"
    )

    # None means "start a new conversation" rather than continuing one.
    conversation_id: Optional[int] = Field(
        default=None,
        description="ID of existing conversation (null to start new conversation)"
    )

    # Overrides the server-side default system prompt when provided.
    system_prompt: Optional[str] = Field(
        default=None,
        max_length=5000,
        description="Optional custom system prompt to override default"
    )

    # Clamped to [0.0, 1.0] here even though some providers accept wider
    # ranges; the provider layer receives this value unchanged.
    temperature: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        description="Sampling temperature for response generation (0.0 to 1.0)"
    )

    # Example payload surfaced in the OpenAPI docs.
    class Config:
        json_schema_extra = {
            "example": {
                "message": "Can you help me organize my tasks for today?",
                "conversation_id": 123,
                "temperature": 0.7
            }
        }


"""Pydantic schema for chat response."""
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime


class ChatResponse(BaseModel):
    """Response schema for chat endpoint.

    Represents the AI assistant's response to a user's message.
    """

    # Always populated: the id of the conversation the reply was stored in,
    # whether it pre-existed or was created for this request.
    conversation_id: int = Field(
        ...,
        description="ID of the conversation (new or existing)"
    )

    message: str = Field(
        ...,
        description="AI assistant's response message"
    )

    # Fixed to "assistant" by the chat endpoint; kept as a field so the
    # payload mirrors the stored Message rows.
    role: str = Field(
        default="assistant",
        description="Role of the message sender (always 'assistant' for responses)"
    )

    timestamp: datetime = Field(
        ...,
        description="Timestamp when the response was generated"
    )

    # Optional because providers may not report usage (estimates are used
    # on free tiers) — see GeminiProvider.count_tokens.
    token_count: Optional[int] = Field(
        default=None,
        description="Number of tokens used in the response"
    )

    model: Optional[str] = Field(
        default=None,
        description="AI model used to generate the response"
    )

    # Example payload surfaced in the OpenAPI docs.
    class Config:
        json_schema_extra = {
            "example": {
                "conversation_id": 123,
                "message": "I'd be happy to help you organize your tasks! Let me know what you need to accomplish today.",
                "role": "assistant",
                "timestamp": "2026-01-14T10:30:00Z",
                "token_count": 25,
                "model": "gemini-pro"
            }
        }
"""Conversation response schemas."""
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Field


class MessageResponse(BaseModel):
    """Response schema for a single message."""

    id: int
    role: str = Field(..., description="Role of the message sender (user or assistant)")
    content: str = Field(..., description="Content of the message")
    timestamp: datetime = Field(..., description="When the message was created")
    # token_count may be absent for rows written before counting existed.
    token_count: Optional[int] = Field(None, description="Estimated token count")

    # Allow construction directly from ORM/SQLModel row objects.
    class Config:
        from_attributes = True


class ConversationSummary(BaseModel):
    """Summary of a conversation for list view."""

    id: int
    title: Optional[str] = Field(None, description="Title of the conversation")
    created_at: datetime = Field(..., description="When the conversation was created")
    updated_at: datetime = Field(..., description="When the conversation was last updated")
    message_count: int = Field(0, description="Number of messages in the conversation")
    last_message_preview: Optional[str] = Field(None, description="Preview of the last message (first 100 chars)")

    # Allow construction directly from ORM/SQLModel row objects.
    class Config:
        from_attributes = True


class ConversationListResponse(BaseModel):
    """Response schema for listing conversations."""

    conversations: List[ConversationSummary]
    total: int = Field(..., description="Total number of conversations")


class ConversationDetail(BaseModel):
    """Detailed conversation information."""

    id: int
    user_id: int
    title: Optional[str] = Field(None, description="Title of the conversation")
    created_at: datetime = Field(..., description="When the conversation was created")
    updated_at: datetime = Field(..., description="When the conversation was last updated")

    # Allow construction directly from ORM/SQLModel row objects.
    class Config:
        from_attributes = True


class MessageListResponse(BaseModel):
    """Response schema for conversation messages."""

    conversation_id: int
    messages: List[MessageResponse]
    total: int = Field(..., description="Total number of messages")


class UpdateConversationRequest(BaseModel):
    """Request schema for updating a conversation."""

    # min_length=1 rejects blank titles; 255 matches a typical VARCHAR
    # column width — TODO confirm against the Conversation model.
    title: str = Field(..., min_length=1, max_length=255, description="New title for the conversation")


class UpdateConversationResponse(BaseModel):
    """Response schema for updating a conversation."""

    id: int
    title: Optional[str]
    updated_at: datetime

    # Allow construction directly from ORM/SQLModel row objects.
    class Config:
        from_attributes = True
+""" + +from typing import Optional, Literal +from pydantic import BaseModel + + +class ErrorResponse(BaseModel): + """Structured error response model.""" + + error_code: str + detail: str + source: Literal["AI_PROVIDER", "AUTHENTICATION", "VALIDATION", "DATABASE", "INTERNAL"] + provider: Optional[str] = None + + class Config: + json_schema_extra = { + "example": { + "error_code": "RATE_LIMIT_EXCEEDED", + "detail": "AI service rate limit exceeded. Please wait a moment and try again.", + "source": "AI_PROVIDER", + "provider": "gemini" + } + } + + +# Error code constants +class ErrorCode: + """Standard error codes for the application.""" + + # AI Provider errors + RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED" + API_KEY_MISSING = "API_KEY_MISSING" + API_KEY_INVALID = "API_KEY_INVALID" + PROVIDER_UNAVAILABLE = "PROVIDER_UNAVAILABLE" + PROVIDER_ERROR = "PROVIDER_ERROR" + + # Authentication errors + UNAUTHORIZED = "UNAUTHORIZED" + TOKEN_EXPIRED = "TOKEN_EXPIRED" + TOKEN_INVALID = "TOKEN_INVALID" + + # Validation errors + INVALID_INPUT = "INVALID_INPUT" + MESSAGE_TOO_LONG = "MESSAGE_TOO_LONG" + MESSAGE_EMPTY = "MESSAGE_EMPTY" + + # Database errors + CONVERSATION_NOT_FOUND = "CONVERSATION_NOT_FOUND" + DATABASE_ERROR = "DATABASE_ERROR" + + # Internal errors + INTERNAL_ERROR = "INTERNAL_ERROR" + UNKNOWN_ERROR = "UNKNOWN_ERROR" diff --git a/src/services/__pycache__/conversation_service.cpython-313.pyc b/src/services/__pycache__/conversation_service.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4001a5152b69023d80eb15f105b0a719ff0de5e7 Binary files /dev/null and b/src/services/__pycache__/conversation_service.cpython-313.pyc differ diff --git a/src/services/__pycache__/llm_service.cpython-313.pyc b/src/services/__pycache__/llm_service.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98c1b25ea4fd69d8f63fa293afbcbf738cd22b05 Binary files /dev/null and b/src/services/__pycache__/llm_service.cpython-313.pyc 
"""Conversation service for CRUD operations."""
from typing import List, Optional
from datetime import datetime
from sqlmodel import Session, select
from src.models.conversation import Conversation
from src.models.message import Message
from src.core.config import settings


class ConversationService:
    """Service for managing conversations and messages.

    Handles CRUD operations for conversations and messages, including
    conversation history retrieval and hybrid (message-count + token)
    trimming so histories stay within free-tier constraints.
    """

    def __init__(self, db: Session):
        """Initialize the conversation service.

        Args:
            db: SQLModel database session
        """
        self.db = db

    @staticmethod
    def _trim_messages(
        messages: List[Message],
        max_messages: int,
        max_tokens: Optional[int],
    ) -> List[Message]:
        """Apply the hybrid trimming strategy (shared helper).

        1. Keep only the most recent ``max_messages`` messages.
        2. Within those, drop the oldest messages until the estimated token
           total is within ``max_tokens`` (always keeping at least one).

        Operates on a copy, so the caller's list is never mutated.
        (Previously this logic was duplicated in trim_conversation_history
        and format_messages_for_agent, and the latter could mutate its
        input list in place via pop(0).)

        Args:
            messages: Messages ordered oldest-first.
            max_messages: Maximum number of messages to keep.
            max_tokens: Maximum total tokens to keep; None/0 disables.

        Returns:
            Trimmed list of Message objects, oldest-first.
        """
        # Copy so pop(0) below cannot mutate the caller's list.
        trimmed = list(messages[-max_messages:]) if len(messages) > max_messages else list(messages)

        if max_tokens:
            # token_count may be None for legacy rows; treat it as 0.
            total_tokens = sum(msg.token_count or 0 for msg in trimmed)
            while total_tokens > max_tokens and len(trimmed) > 1:
                total_tokens -= trimmed.pop(0).token_count or 0

        return trimmed

    def create_conversation(self, user_id: int, title: Optional[str] = None) -> Conversation:
        """Create a new conversation for a user.

        Args:
            user_id: ID of the user creating the conversation
            title: Optional title for the conversation (defaults to
                "New Conversation" when omitted)

        Returns:
            Created Conversation object
        """
        conversation = Conversation(
            user_id=user_id,
            title=title or "New Conversation",
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow()
        )
        self.db.add(conversation)
        self.db.commit()
        self.db.refresh(conversation)
        return conversation

    def get_conversation(self, conversation_id: int, user_id: int) -> Optional[Conversation]:
        """Get a conversation by ID, ensuring it belongs to the user.

        Args:
            conversation_id: ID of the conversation
            user_id: ID of the user (for authorization)

        Returns:
            Conversation object if found and authorized, None otherwise
        """
        statement = select(Conversation).where(
            Conversation.id == conversation_id,
            Conversation.user_id == user_id
        )
        return self.db.exec(statement).first()

    def get_user_conversations(self, user_id: int, limit: int = 50) -> List[Conversation]:
        """Get all conversations for a user, most recently updated first.

        Args:
            user_id: ID of the user
            limit: Maximum number of conversations to return

        Returns:
            List of Conversation objects
        """
        statement = (
            select(Conversation)
            .where(Conversation.user_id == user_id)
            .order_by(Conversation.updated_at.desc())
            .limit(limit)
        )
        return list(self.db.exec(statement).all())

    def add_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        token_count: Optional[int] = None
    ) -> Message:
        """Add a message to a conversation.

        Also bumps the parent conversation's updated_at timestamp so it
        sorts to the top of get_user_conversations.

        Args:
            conversation_id: ID of the conversation
            role: Role of the message sender ("user" or "assistant")
            content: Content of the message
            token_count: Optional token count for the message

        Returns:
            Created Message object
        """
        message = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
            token_count=token_count,
            timestamp=datetime.utcnow()
        )
        self.db.add(message)

        # Update conversation's updated_at timestamp
        conversation = self.db.get(Conversation, conversation_id)
        if conversation:
            conversation.updated_at = datetime.utcnow()

        self.db.commit()
        self.db.refresh(message)
        return message

    def get_conversation_messages(
        self,
        conversation_id: int,
        limit: Optional[int] = None
    ) -> List[Message]:
        """Get all messages for a conversation, oldest first.

        Args:
            conversation_id: ID of the conversation
            limit: Optional limit on number of messages to return

        Returns:
            List of Message objects ordered by timestamp
        """
        statement = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.timestamp.asc())
        )

        if limit:
            statement = statement.limit(limit)

        return list(self.db.exec(statement).all())

    def trim_conversation_history(
        self,
        conversation_id: int,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None
    ) -> List[Message]:
        """Trim conversation history based on message count and token limits.

        Delegates to the shared hybrid trimming strategy (_trim_messages).

        Args:
            conversation_id: ID of the conversation
            max_messages: Maximum number of messages to keep (default from settings)
            max_tokens: Maximum total tokens to keep (default from settings)

        Returns:
            List of trimmed Message objects
        """
        max_messages = max_messages or settings.MAX_CONVERSATION_MESSAGES
        max_tokens = max_tokens or settings.MAX_CONVERSATION_TOKENS

        all_messages = self.get_conversation_messages(conversation_id)
        return self._trim_messages(all_messages, max_messages, max_tokens)

    def delete_conversation(self, conversation_id: int, user_id: int) -> bool:
        """Delete a conversation and all its messages.

        Args:
            conversation_id: ID of the conversation
            user_id: ID of the user (for authorization)

        Returns:
            True if deleted, False if not found or unauthorized
        """
        conversation = self.get_conversation(conversation_id, user_id)
        if not conversation:
            return False

        self.db.delete(conversation)
        self.db.commit()
        return True

    def format_messages_for_agent(
        self,
        messages: List[Message],
        max_messages: int = 20,
        max_tokens: int = 8000
    ) -> List[dict]:
        """Format messages for agent consumption with trimming.

        Converts Message objects to the format expected by the agent:
        [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]

        Applies the shared hybrid trimming strategy first so the payload
        stays within free-tier constraints. The input list is not mutated.

        Args:
            messages: List of Message objects from database (oldest-first)
            max_messages: Maximum number of messages to keep (default: 20)
            max_tokens: Maximum total tokens to keep (default: 8000)

        Returns:
            List of formatted message dicts for agent
        """
        trimmed = self._trim_messages(messages, max_messages, max_tokens)
        return [{"role": msg.role, "content": msg.content} for msg in trimmed]
"""LLM Service with provider factory pattern."""
from typing import Dict, Any, List
from src.core.config import settings
from src.services.providers.base import LLMProvider
from src.services.providers.gemini import GeminiProvider


class LLMService:
    """Service for managing LLM provider interactions.

    Implements a factory pattern to instantiate the correct provider
    based on configuration. Handles provider selection and response generation.
    """

    # Registry of supported provider implementations. When adding a
    # provider, also register its settings attribute in _api_key_settings.
    _providers: Dict[str, type[LLMProvider]] = {
        "gemini": GeminiProvider,
        # Future providers can be added here:
        # "openrouter": OpenRouterProvider,
        # "cohere": CohereProvider,
    }

    # Name of the settings attribute holding each provider's API key.
    # Declarative mapping replaces the previous if/elif chain so the two
    # tables stay side by side and in sync.
    _api_key_settings: Dict[str, str] = {
        "gemini": "GEMINI_API_KEY",
        "openrouter": "OPENROUTER_API_KEY",
        "cohere": "COHERE_API_KEY",
    }

    def __init__(self):
        """Initialize the LLM service with configured provider."""
        self.provider = self._create_provider()

    def _create_provider(self) -> LLMProvider:
        """Factory method to create the appropriate LLM provider.

        Returns:
            Configured LLM provider instance

        Raises:
            ValueError: If provider is not supported or API key is missing
        """
        provider_name = settings.AI_PROVIDER.lower()

        if provider_name not in self._providers:
            raise ValueError(
                f"Unsupported AI provider: {provider_name}. "
                f"Supported providers: {', '.join(self._providers.keys())}"
            )

        # Resolve the API key via the declarative mapping.
        key_attr = self._api_key_settings.get(provider_name)
        api_key = getattr(settings, key_attr, None) if key_attr else None

        if not api_key:
            raise ValueError(
                f"API key not configured for provider: {provider_name}. "
                f"Please set the appropriate API key in .env file."
            )

        # Instantiate provider
        provider_class = self._providers[provider_name]
        return provider_class(api_key=api_key)

    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        system_prompt: str | None = None,
        max_tokens: int | None = None,
        temperature: float = 0.7
    ) -> Dict[str, Any]:
        """Generate a response from the configured LLM provider.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            system_prompt: Optional system prompt to guide the AI's behavior
            max_tokens: Maximum tokens to generate in the response
            temperature: Sampling temperature (0.0 to 1.0)

        Returns:
            Dict containing:
            - content: The generated response text
            - token_count: Number of tokens used
            - model: Model name used

        Raises:
            Exception: If the provider API call fails
        """
        return await self.provider.generate_response(
            messages=messages,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature
        )

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text string.

        Delegates to the active provider's (possibly estimated) counter.

        Args:
            text: The text to count tokens for

        Returns:
            Number of tokens in the text
        """
        return self.provider.count_tokens(text)

    @staticmethod
    def get_default_system_prompt() -> str:
        """Get the default system prompt for the chatbot with intent recognition.

        Returns:
            Default system prompt string with intent detection guidance
        """
        return """You are a helpful AI assistant for a todo list application.
You can help users manage their tasks, answer questions, and provide assistance.

## Intent Recognition
You should recognize and acknowledge the following todo-related intents:
- **Add Task**: User wants to create a new task (e.g., "add a task", "create a todo", "remind me to...")
- **Update Task**: User wants to modify an existing task (e.g., "update task", "change the title", "mark as complete")
- **Delete Task**: User wants to remove a task (e.g., "delete task", "remove todo", "cancel that")
- **List Tasks**: User wants to see their tasks (e.g., "show my tasks", "what do I need to do", "list todos")
- **General Help**: User needs assistance or has questions

## Current Capabilities (Phase 1)
In Phase 1, you can engage in natural conversation but cannot yet perform task operations.
When users express todo-related intents, you should:
1. **Acknowledge** their intent clearly (e.g., "I understand you want to add a task...")
2. **Explain** that task management features will be available in Phase 2
3. **Ask clarifying questions** if the request is ambiguous (e.g., "What would you like the task to be about?")
4. **Be encouraging** and let them know you're here to help once the feature is ready

## Response Guidelines
- Be friendly, concise, and helpful
- Use natural, conversational language
- Ask clarifying questions when needed
- Show empathy and understanding
- Maintain context across the conversation"""

    @staticmethod
    def get_intent_acknowledgment_template(intent: str) -> str:
        """Get acknowledgment template for a specific intent.

        Args:
            intent: The detected intent (add_task, update_task, delete_task, list_tasks)

        Returns:
            Acknowledgment template string (falls back to general_help for
            unknown intents)
        """
        templates = {
            "add_task": """I understand you want to add a new task! 📝

While I can't create tasks yet (this feature is coming in Phase 2), I'd be happy to help you plan it out. What would you like the task to be about?""",

            "update_task": """I see you want to update a task! ✏️

Task editing capabilities will be available in Phase 2. In the meantime, I can help you think through what changes you'd like to make.""",

            "delete_task": """I understand you want to remove a task! 🗑️

Task deletion will be available in Phase 2. For now, I can help you organize your thoughts about which tasks to keep or remove.""",

            "list_tasks": """I see you want to view your tasks! 📋

Task listing functionality will be available in Phase 2. Once it's ready, you'll be able to see all your tasks, filter them, and manage them easily.""",

            "general_help": """I'm here to help! 🤝

Right now, I can chat with you about your tasks and help you plan. Full task management features (add, update, delete, list) will be available in Phase 2.

What would you like to know or discuss?"""
        }

        return templates.get(intent, templates["general_help"])
+ + All provider implementations (Gemini, OpenRouter, Cohere) must inherit from this class + and implement the generate_response method. + """ + + def __init__(self, api_key: str, model_name: str): + """Initialize the LLM provider. + + Args: + api_key: API key for the provider + model_name: Name of the model to use + """ + self.api_key = api_key + self.model_name = model_name + + @abstractmethod + async def generate_response( + self, + messages: List[Dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float = 0.7 + ) -> Dict[str, Any]: + """Generate a response from the AI model. + + Args: + messages: List of message dicts with 'role' and 'content' keys + system_prompt: Optional system prompt to guide the AI's behavior + max_tokens: Maximum tokens to generate in the response + temperature: Sampling temperature (0.0 to 1.0) + + Returns: + Dict containing: + - content: The generated response text + - token_count: Number of tokens used (if available) + - model: Model name used + + Raises: + Exception: If the API call fails + """ + pass + + @abstractmethod + def count_tokens(self, text: str) -> int: + """Count the number of tokens in a text string. + + Args: + text: The text to count tokens for + + Returns: + Number of tokens in the text + """ + pass diff --git a/src/services/providers/gemini.py b/src/services/providers/gemini.py new file mode 100644 index 0000000000000000000000000000000000000000..eeaef045b3a939f8daab7c38cde81565e323827c --- /dev/null +++ b/src/services/providers/gemini.py @@ -0,0 +1,116 @@ +"""Gemini AI provider implementation.""" +import google.generativeai as genai +from typing import List, Dict, Any +from .base import LLMProvider + + +class GeminiProvider(LLMProvider): + """Google Gemini AI provider implementation. + + Uses the google-generativeai library to interact with Gemini models. + Supports gemini-pro and other Gemini model variants. 
+ """ + + def __init__(self, api_key: str, model_name: str = "google/gemini-2.0-flash-exp:free"): + """Initialize the Gemini provider. + + Args: + api_key: Google AI API key + model_name: Gemini model name (default: gemini-pro) + """ + super().__init__(api_key, model_name) + genai.configure(api_key=api_key) + self.model = genai.GenerativeModel(model_name) + + async def generate_response( + self, + messages: List[Dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float = 0.7 + ) -> Dict[str, Any]: + """Generate a response from Gemini. + + Args: + messages: List of message dicts with 'role' and 'content' keys + system_prompt: Optional system prompt to guide the AI's behavior + max_tokens: Maximum tokens to generate in the response + temperature: Sampling temperature (0.0 to 1.0) + + Returns: + Dict containing: + - content: The generated response text + - token_count: Number of tokens used (estimated) + - model: Model name used + + Raises: + Exception: If the Gemini API call fails + """ + try: + # Build the conversation history for Gemini + # Gemini expects a list of content parts + chat_history = [] + + # Add system prompt if provided + if system_prompt: + chat_history.append({ + "role": "user", + "parts": [system_prompt] + }) + chat_history.append({ + "role": "model", + "parts": ["Understood. 
I will follow these instructions."] + }) + + # Convert messages to Gemini format + for msg in messages: + role = "model" if msg["role"] == "assistant" else "user" + chat_history.append({ + "role": role, + "parts": [msg["content"]] + }) + + # Start chat with history + chat = self.model.start_chat(history=chat_history[:-1]) # Exclude last message + + # Generate response + generation_config = genai.types.GenerationConfig( + temperature=temperature, + max_output_tokens=max_tokens + ) + + response = chat.send_message( + chat_history[-1]["parts"][0], + generation_config=generation_config + ) + + # Extract response content + content = response.text + + # Estimate token count (Gemini doesn't provide exact counts in free tier) + token_count = self.count_tokens(content) + + return { + "content": content, + "token_count": token_count, + "model": self.model_name + } + + except Exception as e: + raise Exception(f"Gemini API error: {str(e)}") + + def count_tokens(self, text: str) -> int: + """Count the number of tokens in a text string. + + Uses a simple estimation: ~4 characters per token (rough approximation). + For more accurate counting, consider using tiktoken library. + + Args: + text: The text to count tokens for + + Returns: + Estimated number of tokens in the text + """ + # Simple estimation: ~4 characters per token + # This is a rough approximation for English text + return len(text) // 4 diff --git a/test_error_handling.py b/test_error_handling.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1909084e571118e4b8d9930304b6fe808e317e --- /dev/null +++ b/test_error_handling.py @@ -0,0 +1,185 @@ +""" +Test script to verify comprehensive error handling implementation. + +This script tests: +1. Backend structured error responses with error codes +2. Correct HTTP status codes (401, 429, 503, 500) +3. 
"""
Test script to verify comprehensive error handling implementation.

This script tests:
1. Backend structured error responses with error codes
2. Correct HTTP status codes (401, 429, 503, 500)
3. Error classification (missing API key, invalid key, rate limit, provider error)

Functions follow pytest naming (test_*) so they are discoverable by pytest;
the __main__ block allows running the file directly, with failures surfacing
as normal tracebacks and a non-zero exit code.
"""

import sys
import os

# Add backend src to path so `core` and `schemas` resolve when the script
# is run directly from the backend directory.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from core.exceptions import (
    classify_ai_error,
    RateLimitExceededException,
    APIKeyMissingException,
    APIKeyInvalidException,
    ProviderUnavailableException,
    ProviderErrorException
)
from schemas.error import ErrorCode


def test_error_classification():
    """Raw provider errors map to the right exception, code and status."""
    # (raw error, provider, expected type, expected error code, expected HTTP status)
    cases = [
        (Exception("Error: 429 rate limit exceeded"), "gemini",
         RateLimitExceededException, ErrorCode.RATE_LIMIT_EXCEEDED, 429),
        (Exception("API key not found"), "openrouter",
         APIKeyMissingException, ErrorCode.API_KEY_MISSING, 503),
        (Exception("401 unauthorized - invalid api key"), "cohere",
         APIKeyInvalidException, ErrorCode.API_KEY_INVALID, 401),
        (Exception("503 service unavailable"), "gemini",
         ProviderUnavailableException, ErrorCode.PROVIDER_UNAVAILABLE, 503),
        (Exception("Something went wrong with the AI service"), "gemini",
         ProviderErrorException, ErrorCode.PROVIDER_ERROR, 500),
    ]
    for raw, provider, expected_type, expected_code, expected_status in cases:
        classified = classify_ai_error(raw, provider=provider)
        assert isinstance(classified, expected_type), raw
        assert classified.error_code == expected_code
        assert classified.status_code == expected_status


def test_error_response_structure():
    """Classified errors expose the full structured-response contract."""
    error = RateLimitExceededException(provider="gemini")

    # All attributes required by the ErrorResponse schema must exist.
    for attr in ("error_code", "detail", "source", "provider", "status_code"):
        assert hasattr(error, attr), attr

    assert error.source == "AI_PROVIDER"
    assert error.provider == "gemini"


def test_error_codes():
    """All documented error-code constants are defined and non-empty."""
    expected = [
        # AI Provider errors
        "RATE_LIMIT_EXCEEDED", "API_KEY_MISSING", "API_KEY_INVALID",
        "PROVIDER_UNAVAILABLE", "PROVIDER_ERROR",
        # Authentication errors
        "UNAUTHORIZED", "TOKEN_EXPIRED", "TOKEN_INVALID",
        # Validation errors
        "INVALID_INPUT", "MESSAGE_TOO_LONG", "MESSAGE_EMPTY",
        # Database errors
        "CONVERSATION_NOT_FOUND", "DATABASE_ERROR",
        # Internal errors
        "INTERNAL_ERROR", "UNKNOWN_ERROR",
    ]
    for name in expected:
        assert getattr(ErrorCode, name), name


if __name__ == "__main__":
    # Direct execution: any assertion failure raises, printing a traceback
    # and exiting non-zero — no hand-rolled except/sys.exit needed.
    for check in (test_error_classification,
                  test_error_response_structure,
                  test_error_codes):
        check()
        print(f"[OK] {check.__name__}")
    print("[SUCCESS] ALL TESTS PASSED!")