diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..bc2dfd0201b057b95ca2fc1f212d39862fa78bc7 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,55 @@ +# Python dependencies +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +venv/ +ENV/ +env/ +.venv + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# OS files +.DS_Store +Thumbs.db + +# Environment +.env* +!.env.example + +# Tests +.pytest_cache/ +.coverage +htmlcov/ +*.cover + +# Debug logs +*.log + +# UV +.uv/ +uv.lock diff --git a/.env b/.env new file mode 100644 index 0000000000000000000000000000000000000000..549747aa3adacc0b938ad698e45a07805af04254 --- /dev/null +++ b/.env @@ -0,0 +1,15 @@ +# Database Configuration +# SECURITY: never commit real credentials; the previously committed Neon password must be rotated. +DATABASE_URL=postgresql://<db-user>:<db-password>@<db-host>/<db-name>?sslmode=require&channel_binding=require + +# JWT Configuration +BETTER_AUTH_SECRET=your-secret-key-change-in-production +JWT_SECRET_KEY=your-jwt-secret-change-in-production +JWT_ALGORITHM=HS256 +ACCESS_TOKEN_EXPIRE_DAYS=7 +JWT_COOKIE_SECURE=True +JWT_COOKIE_SAMESITE=none + +# CORS Configuration +FRONTEND_URL=https://task-flow-roan-beta.vercel.app,http://localhost:3000,http://127.0.0.1:3000 +OPENAI_API_KEY=<redacted-leaked-key-rotate-immediately> +GEMINI_API_KEY=<redacted-leaked-key-rotate-immediately> \ No newline at end of file diff --git a/.python-version b/.python-version new file mode 100644 index 0000000000000000000000000000000000000000..e4fba2183587225f216eeada4c78dfab6b2e65f5 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index
0000000000000000000000000000000000000000..e6054ee6cee45e5aa0ea1873e0f959d4d0577859 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,273 @@ +# Claude Agent Instructions - Backend + +## Context + +You are working in the **FastAPI backend** of a full-stack task management application. + +**Parent Instructions**: See root `CLAUDE.md` for global rules. + +## Technology Stack + +- **FastAPI** 0.115+ +- **SQLModel** 0.0.24+ (NOT raw SQLAlchemy) +- **Pydantic v2** for validation +- **PostgreSQL 16** via Neon +- **UV** package manager +- **Alembic** for migrations +- **Python 3.13+** + +## Critical Requirements + +### SQLModel (NOT SQLAlchemy) + +**Correct** (SQLModel): +```python +from sqlmodel import SQLModel, Field, Relationship + +class User(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + email: str = Field(unique=True, index=True) + password_hash: str + + tasks: list["Task"] = Relationship(back_populates="owner") +``` + +**Forbidden** (raw SQLAlchemy): +```python +from sqlalchemy import Column, Integer, String # NO! +``` + +### User Data Isolation (CRITICAL) + +**ALWAYS filter by user_id**: +```python +from fastapi import Depends, HTTPException +from sqlmodel import select + +async def get_user_tasks( + user_id: int, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session) +): + # Verify ownership + if user_id != current_user.id: + raise HTTPException(status_code=404) # NOT 403! 
+ + # Filter by user_id + statement = select(Task).where(Task.user_id == user_id) + tasks = session.exec(statement).all() + return tasks +``` + +### JWT Authentication + +**Token Validation**: +```python +from jose import jwt, JWTError +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPBearer + +security = HTTPBearer() + +async def get_current_user( + token: str = Depends(security) +) -> User: + try: + payload = jwt.decode( + token.credentials, + settings.BETTER_AUTH_SECRET, + algorithms=[settings.JWT_ALGORITHM] + ) + user_id: int = payload.get("sub") + if user_id is None: + raise HTTPException(status_code=401) + except JWTError: + raise HTTPException(status_code=401) + + user = get_user_from_db(user_id) + if user is None: + raise HTTPException(status_code=401) + return user +``` + +### Password Security + +```python +from passlib.context import CryptContext + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +def hash_password(password: str) -> str: + return pwd_context.hash(password) + +def verify_password(plain: str, hashed: str) -> bool: + return pwd_context.verify(plain, hashed) +``` + +## Project Structure + +``` +src/ +├── main.py # FastAPI app, CORS, startup +├── config.py # Environment variables +├── database.py # SQLModel engine, session +├── models/ +│ ├── user.py # User SQLModel +│ └── task.py # Task SQLModel +├── schemas/ +│ ├── auth.py # Request/response schemas +│ └── task.py # Request/response schemas +├── routers/ +│ ├── auth.py # /api/auth/* endpoints +│ └── tasks.py # /api/{user_id}/tasks/* endpoints +├── middleware/ +│ └── auth.py # JWT validation +└── utils/ + ├── security.py # bcrypt, JWT helpers + └── deps.py # Dependency injection +``` + +## API Patterns + +### Endpoint Structure + +```python +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import Session + +router = APIRouter(prefix="/api/{user_id}/tasks", tags=["tasks"]) + +@router.get("/", 
response_model=list[TaskResponse]) +async def list_tasks( + user_id: int, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session) +): + # Authorization check + if user_id != current_user.id: + raise HTTPException(status_code=404) + + # Query with user_id filter + statement = select(Task).where(Task.user_id == user_id) + tasks = session.exec(statement).all() + return tasks +``` + +### Error Responses + +```python +# 401 Unauthorized - Invalid/missing JWT +raise HTTPException( + status_code=401, + detail="Invalid authentication credentials" +) + +# 404 Not Found - Resource doesn't exist OR unauthorized access +raise HTTPException( + status_code=404, + detail="Task not found" +) + +# 400 Bad Request - Validation error +raise HTTPException( + status_code=400, + detail="Title must be between 1-200 characters" +) + +# 409 Conflict - Duplicate resource +raise HTTPException( + status_code=409, + detail="An account with this email already exists" +) +``` + +## Database Migrations + +**Creating Migrations**: +```bash +uv run alembic revision --autogenerate -m "Add users and tasks tables" +``` + +**Applying Migrations**: +```bash +uv run alembic upgrade head +``` + +**Migration File Structure**: +```python +def upgrade(): + op.create_table( + 'user', + sa.Column('id', sa.Integer(), primary_key=True), + sa.Column('email', sa.String(), unique=True), + sa.Column('password_hash', sa.String()), + ) + op.create_index('ix_user_email', 'user', ['email']) +``` + +## Testing + +**Fixtures** (`tests/conftest.py`): +```python +import pytest +from sqlmodel import Session, create_engine +from fastapi.testclient import TestClient + +@pytest.fixture +def session(): + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session + +@pytest.fixture +def client(session): + app.dependency_overrides[get_session] = lambda: session + yield TestClient(app) +``` + +**Test Example**: +```python 
+def test_create_task(client, auth_headers): + response = client.post( + "/api/1/tasks", + headers=auth_headers, + json={"title": "Test Task", "description": "Test"} + ) + assert response.status_code == 201 + assert response.json()["title"] == "Test Task" +``` + +## Environment Variables + +Required in `.env`: +``` +DATABASE_URL=postgresql://taskuser:taskpassword@db:5432/taskdb +BETTER_AUTH_SECRET=your-secret-key-change-in-production +JWT_SECRET_KEY=your-jwt-secret-change-in-production +JWT_ALGORITHM=HS256 +ACCESS_TOKEN_EXPIRE_DAYS=7 +``` + +## Common Mistakes to Avoid + +❌ Using raw SQLAlchemy instead of SQLModel +✅ Use SQLModel for all database models + +❌ Trusting user_id from request parameters +✅ Always extract from validated JWT token + +❌ Returning 403 for unauthorized access +✅ Return 404 to prevent information leakage + +❌ SQL string concatenation +✅ SQLModel parameterized queries only + +❌ Plaintext passwords +✅ bcrypt hashing always + +## References + +- Root Instructions: `../CLAUDE.md` +- Feature Spec: `../specs/001-task-crud-auth/spec.md` +- Constitution: `../.specify/memory/constitution.md` \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..025f30e6a7b88fbbfa658d19a52f107457c2562f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,65 @@ +# Multi-stage Dockerfile for FastAPI backend +# Stage 1: Builder - Install dependencies and prepare the application +FROM python:3.13-slim AS builder + +# Set working directory +WORKDIR /app + +# Install system dependencies for building +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# Install UV for fast dependency management +RUN pip install --no-cache-dir uv + +# Copy pyproject.toml first for better layer caching +COPY pyproject.toml uv.lock* ./ + +# Create virtual environment and install dependencies +# Note: UV_SYSTEM_PYTHON=0 ensures packages are installed to venv, not system 
+ENV UV_SYSTEM_PYTHON=0 +RUN uv venv /app/.venv && \ + . /app/.venv/bin/activate && \ + uv pip install -r pyproject.toml + +# Copy the rest of the application +COPY . . + +# Stage 2: Runner - Production-ready image +FROM python:3.13-slim AS runner + +# Create non-root user for security +RUN groupadd --system --gid 1001 appgroup && \ + useradd --system --uid 1001 --gid appgroup --shell /bin/false --create-home appuser + +# Set working directory +WORKDIR /app + +# Copy virtual environment from builder +COPY --from=builder /app/.venv /app/.venv + +# Copy application code +COPY --from=builder /app/src /app/src +COPY --from=builder /app/pyproject.toml /app/pyproject.toml + +# Set environment variables +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +# Set ownership +RUN chown -R appuser:appgroup /app + +# Use non-root user +USER appuser + +# Expose the application port +EXPOSE 8000 + +# Health check +HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ + CMD wget --no-verbose --tries=1 --spider http://localhost:${PORT:-8000}/health || exit 1 + +# Start the application using explicit path to uvicorn with dynamic port support +# Default to port 8000 if PORT env var is not set +CMD ["/bin/sh", "-c", "/app/.venv/bin/uvicorn src.main:app --host 0.0.0.0 --port ${PORT:-8000}"] diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000000000000000000000000000000000000..9f6c3ea9fcc033a267e75b97eb4814d6ccc2e471 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,145 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. 
forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. 
+# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +# SECURITY: do not hardcode real credentials here; the previously committed Neon +# password must be rotated. Override this placeholder from the DATABASE_URL +# environment variable inside env.py. +sqlalchemy.url = driver://user:pass@localhost/dbname + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts.
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S \ No newline at end of file diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000000000000000000000000000000000000..98e4f9c44effe479ed38c66ba922e7bcc672916f --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. 
\ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000000000000000000000000000000000000..ba0a9d6f8ca605bb0aad385da6f053bb71b2e6a3 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,90 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool +from alembic import context + +# Import SQLModel and models +from sqlmodel import SQLModel + +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +from src.models.user import User # Import your models +from src.models.task import Task # Import your models +from src.models.project import Project # Import your models +from src.models.conversation import Conversation # Import your models +from src.models.message import Message # Import your models +from src.models.audit_log import AuditLog # Import your models + + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +target_metadata = SQLModel.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. 
+ + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() \ No newline at end of file diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000000000000000000000000000000000000..11016301e749297acb67822efc7974ee53c905c6 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/3b6c60669e48_add_project_model_and_relationship_to_.py b/alembic/versions/3b6c60669e48_add_project_model_and_relationship_to_.py new file mode 100644 index 0000000000000000000000000000000000000000..6e80e783bede0b40a85fcd19d45e3a4826e0c935 --- /dev/null +++ b/alembic/versions/3b6c60669e48_add_project_model_and_relationship_to_.py @@ -0,0 +1,52 @@ +"""Add Project model and relationship to Task + +Revision ID: 3b6c60669e48 +Revises: ec70eaafa7b6 +Create Date: 2025-12-19 03:46:01.389687 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = '3b6c60669e48' +down_revision: Union[str, Sequence[str], None] = 'ec70eaafa7b6' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('project', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('name', sa.String(length=200), nullable=False), + sa.Column('description', sa.String(length=1000), nullable=True), + sa.Column('color', sa.String(length=7), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('deadline', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_project_user_id'), 'project', ['user_id'], unique=False) + op.add_column('task', sa.Column('project_id', sa.Uuid(), nullable=True)) + op.create_index(op.f('ix_task_project_id'), 'task', ['project_id'], unique=False) + op.create_foreign_key(None, 'task', 'project', ['project_id'], ['id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'task', type_='foreignkey') + op.drop_index(op.f('ix_task_project_id'), table_name='task') + op.drop_column('task', 'project_id') + op.drop_index(op.f('ix_project_user_id'), table_name='project') + op.drop_table('project') + # ### end Alembic commands ### \ No newline at end of file diff --git a/alembic/versions/4ac448e3f100_add_due_date_field_to_task_model.py b/alembic/versions/4ac448e3f100_add_due_date_field_to_task_model.py new file mode 100644 index 0000000000000000000000000000000000000000..10712c46bb989bb92a3a63d6292f60a808c96967 --- /dev/null +++ b/alembic/versions/4ac448e3f100_add_due_date_field_to_task_model.py @@ -0,0 +1,32 @@ +"""Add due_date field to Task model + +Revision ID: 4ac448e3f100 +Revises: 3b6c60669e48 +Create Date: 2025-12-19 03:50:35.687835 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = '4ac448e3f100' +down_revision: Union[str, Sequence[str], None] = '3b6c60669e48' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('task', sa.Column('due_date', sa.DateTime(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('task', 'due_date') + # ### end Alembic commands ### diff --git a/alembic/versions/6f0b6403a1d8_add_refresh_token_table.py b/alembic/versions/6f0b6403a1d8_add_refresh_token_table.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b3811ce5c554b389fed06f41281d8b7c349321 --- /dev/null +++ b/alembic/versions/6f0b6403a1d8_add_refresh_token_table.py @@ -0,0 +1,43 @@ +"""add refresh token table + +Revision ID: 6f0b6403a1d8 +Revises: 4ac448e3f100 +Create Date: 2025-12-24 01:48:38.858071 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = '6f0b6403a1d8' +down_revision: Union[str, Sequence[str], None] = '4ac448e3f100' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Create refresh_tokens table + op.create_table('refresh_tokens', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('token', sa.String(), nullable=False), + sa.Column('expires_at', sa.DateTime(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('token') + ) + op.create_index(op.f('ix_refresh_tokens_user_id'), 'refresh_tokens', ['user_id'], unique=False) + op.create_index(op.f('ix_refresh_tokens_token'), 'refresh_tokens', ['token'], unique=True) + + +def downgrade() -> None: + """Downgrade schema.""" + # Drop refresh_tokens table + op.drop_index(op.f('ix_refresh_tokens_token'), table_name='refresh_tokens') + op.drop_index(op.f('ix_refresh_tokens_user_id'), table_name='refresh_tokens') + op.drop_table('refresh_tokens') diff --git a/alembic/versions/8e3b5a7c2d9f_add_conversation_message_tables.py b/alembic/versions/8e3b5a7c2d9f_add_conversation_message_tables.py new file mode 100644 index 0000000000000000000000000000000000000000..ecf4fd2d70b15cf48c8145de11350bc93be346ff --- /dev/null +++ b/alembic/versions/8e3b5a7c2d9f_add_conversation_message_tables.py @@ -0,0 +1,59 @@ +"""add conversation and message tables + +Revision ID: 8e3b5a7c2d9f +Revises: 6f0b6403a1d8 +Create Date: 2025-12-24 01:50:45.930037 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = '8e3b5a7c2d9f' +down_revision: Union[str, Sequence[str], None] = '6f0b6403a1d8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = '6f0b6403a1d8' + + +def upgrade() -> None: + """Upgrade schema.""" + # Create conversation table + op.create_table('conversation', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_conversation_user_id'), 'conversation', ['user_id'], unique=False) + + # Create message table + op.create_table('message', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('conversation_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('role', sa.Enum('user', 'assistant', name='message_role'), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['conversation_id'], ['conversation.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_message_conversation_id'), 'message', ['conversation_id'], unique=False) + + +def downgrade() -> None: + """Downgrade schema.""" + # Drop message table + op.drop_index(op.f('ix_message_conversation_id'), table_name='message') + op.drop_table('message') + + # Drop conversation table + op.drop_index(op.f('ix_conversation_user_id'), table_name='conversation') + op.drop_table('conversation') + + # Drop enum type + op.execute('DROP TYPE IF EXISTS message_role;') diff --git a/alembic/versions/9a4b8c7d1e2f_add_is_ai_generated_to_tasks.py b/alembic/versions/9a4b8c7d1e2f_add_is_ai_generated_to_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..53a0cb9b1f1622de1b4b31ca991af01f41885017 --- 
/dev/null +++ b/alembic/versions/9a4b8c7d1e2f_add_is_ai_generated_to_tasks.py @@ -0,0 +1,27 @@ +"""add is_ai_generated to tasks + +Revision ID: 9a4b8c7d1e2f +Revises: 8e3b5a7c2d9f +Create Date: 2025-12-25 05:47:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +import uuid + + +# revision identifiers +revision = '9a4b8c7d1e2f' +down_revision = '8e3b5a7c2d9f' +branch_labels = None +depends_on = None + + +def upgrade(): + # Add the is_ai_generated column to the tasks table + op.add_column('task', sa.Column('is_ai_generated', sa.Boolean(), nullable=False, server_default='false')) + + +def downgrade(): + # Remove the is_ai_generated column from the tasks table + op.drop_column('task', 'is_ai_generated') \ No newline at end of file diff --git a/alembic/versions/__pycache__/3b6c60669e48_add_project_model_and_relationship_to_.cpython-312.pyc b/alembic/versions/__pycache__/3b6c60669e48_add_project_model_and_relationship_to_.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10f1538bbbe000f1573fc3b184de2d6e398165ee Binary files /dev/null and b/alembic/versions/__pycache__/3b6c60669e48_add_project_model_and_relationship_to_.cpython-312.pyc differ diff --git a/alembic/versions/__pycache__/4ac448e3f100_add_due_date_field_to_task_model.cpython-312.pyc b/alembic/versions/__pycache__/4ac448e3f100_add_due_date_field_to_task_model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eceb29b664c2b06dfead29305ad4a6cd082e16bd Binary files /dev/null and b/alembic/versions/__pycache__/4ac448e3f100_add_due_date_field_to_task_model.cpython-312.pyc differ diff --git a/alembic/versions/__pycache__/6f0b6403a1d8_add_refresh_token_table.cpython-312.pyc b/alembic/versions/__pycache__/6f0b6403a1d8_add_refresh_token_table.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33d3ad0456c1e3076f0b678751291665a81b2c8a Binary files /dev/null and 
b/alembic/versions/__pycache__/6f0b6403a1d8_add_refresh_token_table.cpython-312.pyc differ diff --git a/alembic/versions/__pycache__/8e3b5a7c2d9f_add_conversation_message_tables.cpython-312.pyc b/alembic/versions/__pycache__/8e3b5a7c2d9f_add_conversation_message_tables.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9379034a38a027a870c5fe9d16cff1be0f627d3 Binary files /dev/null and b/alembic/versions/__pycache__/8e3b5a7c2d9f_add_conversation_message_tables.cpython-312.pyc differ diff --git a/alembic/versions/__pycache__/9a4b8c7d1e2f_add_is_ai_generated_to_tasks.cpython-312.pyc b/alembic/versions/__pycache__/9a4b8c7d1e2f_add_is_ai_generated_to_tasks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4880c0302e4f37c5c4a5cc51ff6ce4708faa9caf Binary files /dev/null and b/alembic/versions/__pycache__/9a4b8c7d1e2f_add_is_ai_generated_to_tasks.cpython-312.pyc differ diff --git a/alembic/versions/__pycache__/a1b2c3d4e5f6_add_audit_log_table.cpython-312.pyc b/alembic/versions/__pycache__/a1b2c3d4e5f6_add_audit_log_table.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15ac29721db0799ea55ec5892d2e2e24b2b9f217 Binary files /dev/null and b/alembic/versions/__pycache__/a1b2c3d4e5f6_add_audit_log_table.cpython-312.pyc differ diff --git a/alembic/versions/__pycache__/ec70eaafa7b6_initial_schema_with_users_and_tasks_.cpython-312.pyc b/alembic/versions/__pycache__/ec70eaafa7b6_initial_schema_with_users_and_tasks_.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..224213cf992bfa445d68e7deb36134ba16838453 Binary files /dev/null and b/alembic/versions/__pycache__/ec70eaafa7b6_initial_schema_with_users_and_tasks_.cpython-312.pyc differ diff --git a/alembic/versions/a1b2c3d4e5f6_add_audit_log_table.py b/alembic/versions/a1b2c3d4e5f6_add_audit_log_table.py new file mode 100644 index 
0000000000000000000000000000000000000000..3a632d539b49494d33d96b451aebb81bca9e94a4 --- /dev/null +++ b/alembic/versions/a1b2c3d4e5f6_add_audit_log_table.py @@ -0,0 +1,43 @@ +"""Add audit_log table + +Revision ID: a1b2c3d4e5f6 +Revises: 9a4b8c7d1e2f +Create Date: 2026-01-31 01:00:00.000000 + +""" +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa +import sqlmodel +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'a1b2c3d4e5f6' +down_revision: Union[str, None] = '9a4b8c7d1e2f' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create audit_log table + op.create_table('auditlog', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('event_id', sa.String(), nullable=False), + sa.Column('event_type', sa.String(length=50), nullable=False), + sa.Column('user_id', sa.String(), nullable=False), + sa.Column('task_id', sa.Integer(), nullable=False), + sa.Column('event_data', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('event_id'), + sa.Index('ix_auditlog_event_id', 'event_id'), + sa.Index('ix_auditlog_user_id', 'user_id'), + sa.Index('ix_auditlog_task_id', 'task_id'), + sa.Index('ix_auditlog_event_type', 'event_type'), + sa.Index('ix_auditlog_timestamp', 'timestamp') + ) + + +def downgrade() -> None: + # Drop audit_log table + op.drop_table('auditlog') diff --git a/alembic/versions/ec70eaafa7b6_initial_schema_with_users_and_tasks_.py b/alembic/versions/ec70eaafa7b6_initial_schema_with_users_and_tasks_.py new file mode 100644 index 0000000000000000000000000000000000000000..89a2b88b7b3a86e9cd85acf5d3b2697ebebbb6a4 --- /dev/null +++ b/alembic/versions/ec70eaafa7b6_initial_schema_with_users_and_tasks_.py @@ -0,0 +1,54 @@ 
"""Initial schema with users and tasks tables.

Revision ID: ec70eaafa7b6
Revises:
Create Date: 2025-12-16 05:07:24.251683
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel

# Revision identifiers, used by Alembic.
revision: str = 'ec70eaafa7b6'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the ``user`` and ``task`` tables and their indexes.

    ``user`` must be created first because ``task.user_id`` carries a
    foreign key to ``user.id``.
    """
    op.create_table(
        'user',
        sa.Column('id', sa.Uuid(), nullable=False),
        sa.Column('email', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
        sa.Column('password_hash', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_index(op.f('ix_user_email'), 'user', ['email'], unique=True)

    op.create_table(
        'task',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('user_id', sa.Uuid(), nullable=False),
        sa.Column('title', sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
        sa.Column('description', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=True),
        sa.Column('completed', sa.Boolean(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['user.id']),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_index(op.f('ix_task_user_id'), 'task', ['user_id'], unique=False)


def downgrade() -> None:
    """Drop both tables in reverse dependency order (task first)."""
    op.drop_index(op.f('ix_task_user_id'), table_name='task')
    op.drop_table('task')
    op.drop_index(op.f('ix_user_email'), table_name='user')
    op.drop_table('user')


# --- main.py (repository root) ---
def main() -> None:
    """Smoke-test entry point printing a hello banner."""
    print("Hello from task-api!")


if __name__ == "__main__":
    main()
import os
from datetime import date

from dotenv import load_dotenv
from openai import OpenAI

from .mcp_server import get_mcp_tools, get_mcp_tools_for_gemin_api

load_dotenv()

# OpenAI-compatible client pointed at Google's Gemini endpoint.
# SECURITY: the key MUST come from the environment — never hard-code it.
client = OpenAI(
    api_key=os.getenv("GEMINI_API_KEY"),
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)

# System prompt for the task-management agent.  FIX: the date used to be
# hard-coded ("2025-12-29") and would go stale; it is now computed at import
# time so relative-date requests ("tomorrow", "next week") stay correct.
AGENT_INSTRUCTIONS = f"""
You are TaskFlow AI, a premium productivity assistant. Your goal is to help users manage their life and work tasks with zero friction.

You have access to several tools:
- `add_task`: Use this when the user wants to create a new task. You can optionally specify a `project_name` to group them.
- `list_tasks`: Use this to show the user their tasks. You can filter by status.
- `complete_task`: Use this to mark a task as done.
- `delete_task`: Use this to remove a task.
- `update_task`: Use this to change a task's title or description.
- `create_project`: Create a new project (collection of tasks) with a specific name, description, and color.
- `list_projects`: Show all existing projects.
- `get_calendar`: Retrieve tasks and events for a date range. "Calendar events" are simply tasks with a due date.

Guidelines:
1. Always be professional, helpful, and concise.
2. If a user's request is vague, ask for clarification.
3. When listing tasks or calendar items, use a clean markdown format. Use bullet points or tables.
4. If a tool call fails, explain the issue politely to the user.
5. You can handle multiple tasks/projects in one go if the user asks for it.
6. Always confirm when an action has been successfully performed.
7. NEVER ask the user for their user_id. It is handled automatically by the system.
8. Today's date is {date.today().isoformat()}. Use this for relative date requests.
9. When the user asks for "calendar" or "schedule", use `get_calendar`.

Your tone should be encouraging and efficient. Let's get things done!
"""


def get_todo_agent():
    """Return the chat-agent configuration (client, model, system prompt)."""
    return {
        "client": client,
        "model": "gemini-2.5-flash",
        "instructions": AGENT_INSTRUCTIONS,
    }


# Module-level singleton used by the chat endpoint.
todo_agent_config = get_todo_agent()


# --- src/config.py ---
from typing import Optional

from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application settings, loaded from the environment / ``.env`` file.

    SECURITY FIX: the previous defaults embedded LIVE credentials (a Neon
    database URL with password, an OpenAI ``sk-proj-`` key and a Google
    ``AIza`` key) directly in source — and the GEMINI/OPENAI values were
    swapped with each other.  All secrets now default to empty/None and must
    be supplied via environment variables; the leaked credentials must be
    rotated.
    """

    # Database — required; supply via DATABASE_URL (no in-source default).
    DATABASE_URL: str = ""

    # Auth — placeholders only; override in production.
    BETTER_AUTH_SECRET: str = "your-secret-key-change-in-production"
    JWT_SECRET_KEY: str = "your-jwt-secret-change-in-production"
    JWT_ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_DAYS: int = 7
    JWT_COOKIE_SECURE: bool = False  # Set to True in production (requires HTTPS)
    JWT_COOKIE_SAMESITE: str = "lax"  # "lax" | "strict" | "none" (use "none" for cross-site cookies)

    # CORS — comma-separated list of allowed frontend origins.
    FRONTEND_URL: str = "https://victorious-mushroom-09538ac1e.2.azurestaticapps.net"

    # AI API keys — optional; AI features fail fast when unset.
    GEMINI_API_KEY: Optional[str] = None
    OPENAI_API_KEY: Optional[str] = None

    class Config:
        env_file = ".env"


settings = Settings()
from contextlib import contextmanager

from sqlmodel import Session, create_engine

from .config import settings

# Module-level engine so the whole application shares one connection pool.
engine = create_engine(
    settings.DATABASE_URL,
    echo=False,          # Set to True for SQL query logging
    pool_pre_ping=True,  # Validate pooled connections before use
    pool_size=5,
    max_overflow=10,
)


@contextmanager
def get_session():
    """Context manager yielding a database session, closed on exit."""
    with Session(engine) as session:
        yield session


def get_session_dep():
    """FastAPI dependency wrapper around :func:`get_session`."""
    with get_session() as session:
        yield session


# --- src/events.py ---
import asyncio
import logging
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Optional

import httpx

from .models.task import Task
from .utils.circuit_breaker import kafka_circuit_breaker
from .utils.metrics import (
    increment_event_published,
    observe_event_publish_duration,
    increment_event_publish_error,
    increment_rate_limiter_request,
    increment_rate_limiter_rejection,
)
from .utils.rate_limiter import event_publisher_rate_limiter

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Retry configuration
MAX_RETRIES = 3
RETRY_DELAY = 1  # seconds (base for exponential backoff)

# Dapr pub/sub endpoint for the task-events topic.
DAPR_PUBLISH_URL = "http://localhost:3500/v1.0/publish/kafka-pubsub/task-events"


def _utc_timestamp() -> str:
    """Return the current UTC time as ISO-8601 with a trailing 'Z'.

    FIX: replaces ``datetime.utcnow()`` (deprecated since Python 3.12) with a
    timezone-aware datetime, while keeping the exact 'Z'-suffixed wire format
    downstream consumers already parse.
    """
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")


async def publish_task_event(event_type: str, task: Task):
    """Publish a task event to Kafka via Dapr.

    Applies, in order: per-user rate limiting, a circuit breaker, and
    retries with exponential backoff.  Failures are logged but never
    propagated (graceful degradation) so the caller's database operation
    always completes regardless of the event bus.

    Args:
        event_type: One of 'created', 'updated', 'completed', 'deleted'.
        task: The task object that triggered the event.
    """
    start_time = time.time()

    # Rate limiting keyed on the owning user.
    user_id = getattr(task, 'user_id', 'unknown')
    rate_limit_key = f"event_publisher:{user_id}"
    increment_rate_limiter_request(rate_limit_key)

    if not event_publisher_rate_limiter.is_allowed(rate_limit_key):
        logger.warning(f"Rate limit exceeded for user {user_id}, event type {event_type}")
        increment_rate_limiter_rejection(rate_limit_key)
        # Skip publishing only; the caller's main operation proceeds.
        logger.info(f"Skipping event publishing due to rate limit for user {user_id}")
        return

    event = {
        "event_id": str(uuid.uuid4()),
        "event_type": event_type,
        "timestamp": _utc_timestamp(),
        "user_id": str(user_id),  # stringified for payload consistency
        "task_id": getattr(task, 'id', 0),
        "task_data": {
            "title": getattr(task, 'title', ''),
            "description": getattr(task, 'description', ''),
            "completed": getattr(task, 'completed', False),
        },
    }

    async def _publish_with_retry():
        """POST the event to Dapr, retrying transport errors with backoff.

        NOTE(review): only ``httpx.RequestError`` (transport failures) is
        retried; HTTP status errors from ``raise_for_status`` propagate
        immediately — confirm this is intentional.
        """
        for attempt in range(MAX_RETRIES):
            try:
                async with httpx.AsyncClient() as client:
                    response = await client.post(DAPR_PUBLISH_URL, json=event)
                    response.raise_for_status()
                logger.info(
                    f"Event published successfully: {event_type} for task {task.id} on attempt {attempt + 1}"
                )
                return
            except httpx.RequestError as e:
                logger.warning(f"Attempt {attempt + 1} failed to publish event: {e}")
                if attempt == MAX_RETRIES - 1:
                    logger.error(f"Failed to publish event after {MAX_RETRIES} attempts: {e}")
                    raise
                # Exponential backoff before the next attempt.
                await asyncio.sleep(RETRY_DELAY * (2 ** attempt))
        # FIX: the previous version had unreachable code here (the loop always
        # returns on success or raises on the final attempt); removed.

    # Circuit breaker wraps the retrying publisher; failures are swallowed
    # after logging so the main operation continues (graceful degradation).
    try:
        await kafka_circuit_breaker.call(_publish_with_retry)
        logger.info(f"Successfully published {event_type} event for task {task.id}")
        increment_event_published(event_type)
    except Exception as e:
        logger.error(f"Event publishing failed for task {task.id}, but main operation continues: {e}")
        increment_event_publish_error(event_type)
    finally:
        # Duration is recorded on both paths (was duplicated in each branch).
        observe_event_publish_duration(event_type, time.time() - start_time)


async def publish_created_event(task: Task):
    """Publish a 'created' event for a new task."""
    await publish_task_event("created", task)


async def publish_updated_event(task: Task):
    """Publish an 'updated' event for a modified task."""
    await publish_task_event("updated", task)


async def publish_deleted_event(task: Task):
    """Publish a 'deleted' event for a deleted task."""
    await publish_task_event("deleted", task)


async def publish_completed_event(task: Task):
    """Publish a 'completed' event for a completed task."""
    await publish_task_event("completed", task)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from .config import settings
from .routers import auth, tasks, projects, chat, audit
from .utils.health_check import kafka_health_checker

app = FastAPI(
    title="Task API",
    description="Task management API with authentication",
    version="1.0.0",
)

# Include routers (order preserved).
app.include_router(auth.router)
app.include_router(tasks.router)
app.include_router(projects.router)
app.include_router(chat.router)
app.include_router(audit.router)

# Allowed origins from settings.FRONTEND_URL (comma-separated).
_frontend_origins = [o.strip() for o in settings.FRONTEND_URL.split(",")] if settings.FRONTEND_URL else []

# CORS configuration (development and production): common dev ports,
# Minikube NodePorts, cluster-internal service names and deployed frontends.
# FIX: deduplicated (order-preserving) — FRONTEND_URL may repeat literals
# below, and the old list passed duplicates straight to the middleware.
allowed_origins = list(dict.fromkeys(_frontend_origins + [
    # Local development (all common ports)
    "http://localhost:3000", "http://localhost:3001", "http://localhost:8000",
    "http://localhost:38905", "http://localhost:40529",  # dynamic dev ports
    "http://127.0.0.1:3000", "http://127.0.0.1:3001", "http://127.0.0.1:8000",
    "http://127.0.0.1:38905", "http://127.0.0.1:40529",  # dynamic dev ports
    # Minikube NodePort (replace with your Minikube IP in production)
    "http://192.168.49.2:30080", "http://192.168.49.2:30081",
    "http://192.168.49.2:30147", "http://192.168.49.2:30148",
    # Kubernetes internal service names (cluster-internal communication)
    "http://todo-chatbot-backend:8000",
    "http://todo-chatbot-frontend:3000",
    "https://ai-powered-full-stack-task-manageme.vercel.app",
    "https://victorious-mushroom-09538ac1e.2.azurestaticapps.net",
]))

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    # Expose origin header for debugging if needed
    expose_headers=["Access-Control-Allow-Origin"],
)


@app.get("/api/health")
async def health_check():
    """Liveness probe: the API process is up."""
    return {"status": "healthy"}


@app.get("/api/health/kafka")
async def kafka_health_check():
    """Check Kafka connectivity through Dapr."""
    health_result = await kafka_health_checker.check_kafka_connectivity()
    return health_result


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
import asyncio
from fastmcp import FastMCP
from pydantic import BaseModel

from .mcp_tools.task_tools import (
    get_task_tools,
    execute_add_task,
    execute_list_tasks,
    execute_complete_task,
    execute_delete_task,
    execute_update_task,
)


def create_mcp_server():
    """Create and configure the MCP server exposing the five task tools.

    Each handler builds the tool's Pydantic params object, delegates to the
    corresponding ``execute_*`` function, and returns a plain dict.
    FIX: results are serialized with ``model_dump()`` — ``.dict()`` is the
    deprecated Pydantic v1 API and this project pins pydantic>=2.
    """
    mcp_server = FastMCP("task-mcp-server")

    @mcp_server.tool(
        name="add_task",
        description="Create a new task for the user",
        input_schema={
            "type": "object",
            "properties": {
                "user_id": {"type": "string", "description": "User ID to create task for"},
                "title": {"type": "string", "description": "Task title, 1-200 characters"},
                "description": {"type": "string", "description": "Task description, optional, max 1000 chars"},
            },
            "required": ["user_id", "title"],
        },
    )
    async def handle_add_task(user_id: str, title: str, description: str = None):
        from .mcp_tools.task_tools import AddTaskParams
        params = AddTaskParams(user_id=user_id, title=title, description=description)
        return execute_add_task(params).model_dump()

    @mcp_server.tool(
        name="list_tasks",
        description="Retrieve user's tasks",
        input_schema={
            "type": "object",
            "properties": {
                "user_id": {"type": "string", "description": "User ID to list tasks for"},
                "status": {"type": "string", "enum": ["all", "pending", "completed"], "default": "all"},
            },
            "required": ["user_id"],
        },
    )
    async def handle_list_tasks(user_id: str, status: str = "all"):
        from .mcp_tools.task_tools import ListTasksParams
        params = ListTasksParams(user_id=user_id, status=status)
        return execute_list_tasks(params).model_dump()

    @mcp_server.tool(
        name="complete_task",
        description="Mark a task as complete",
        input_schema={
            "type": "object",
            "properties": {
                "user_id": {"type": "string", "description": "User ID of the task owner"},
                "task_id": {"type": "integer", "description": "ID of the task to update"},
            },
            "required": ["user_id", "task_id"],
        },
    )
    async def handle_complete_task(user_id: str, task_id: int):
        from .mcp_tools.task_tools import CompleteTaskParams
        params = CompleteTaskParams(user_id=user_id, task_id=task_id)
        return execute_complete_task(params).model_dump()

    @mcp_server.tool(
        name="delete_task",
        description="Remove a task",
        input_schema={
            "type": "object",
            "properties": {
                "user_id": {"type": "string", "description": "User ID of the task owner"},
                "task_id": {"type": "integer", "description": "ID of the task to delete"},
            },
            "required": ["user_id", "task_id"],
        },
    )
    async def handle_delete_task(user_id: str, task_id: int):
        from .mcp_tools.task_tools import DeleteTaskParams
        params = DeleteTaskParams(user_id=user_id, task_id=task_id)
        return execute_delete_task(params).model_dump()

    @mcp_server.tool(
        name="update_task",
        description="Modify task details",
        input_schema={
            "type": "object",
            "properties": {
                "user_id": {"type": "string", "description": "User ID of the task owner"},
                "task_id": {"type": "integer", "description": "ID of the task to update"},
                "title": {"type": "string", "description": "New task title"},
                "description": {"type": "string", "description": "New task description"},
            },
            "required": ["user_id", "task_id"],
        },
    )
    async def handle_update_task(user_id: str, task_id: int, title: str = None, description: str = None):
        from .mcp_tools.task_tools import UpdateTaskParams
        params = UpdateTaskParams(user_id=user_id, task_id=task_id, title=title, description=description)
        return execute_update_task(params).model_dump()

    return mcp_server


# Lazily-created global MCP server instance.
_mcp_server_instance = None


def get_mcp_server_instance():
    """Return the singleton MCP server, creating it on first use."""
    global _mcp_server_instance
    if _mcp_server_instance is None:
        _mcp_server_instance = create_mcp_server()
    return _mcp_server_instance


def get_mcp_tools():
    """Return MCP-format tool definitions for agent registration."""
    from .mcp_tools.task_tools import get_task_tools
    return get_task_tools()


def get_mcp_tools_for_gemin_api():
    """Return tool definitions in Gemini/OpenAI function-calling format."""
    from .mcp_tools.task_tools import get_task_tools_for_gemin_api
    return get_task_tools_for_gemin_api()


# Run the server if this file is executed directly.
if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1 and sys.argv[1] == "--stdio":
        # Run the server using stdio transport.
        from fastmcp.stdio import run_stdio_server
        run_stdio_server(get_mcp_server_instance())
    else:
        print("Usage: python mcp_server.py --stdio")


# --- src/mcp_tools/__init__.py ---
# from .task_tools import get_task_tools
# __all__ = ["get_task_tools"]
# No MCP import needed - this file only defines tool parameters and functions.
# FIX: duplicate imports removed (`Session, select` and `List, Optional` were
# each imported twice).
import json
import os
import sys
import uuid
from datetime import datetime
from typing import List, Optional
from uuid import UUID

from pydantic import BaseModel, Field
from sqlmodel import Session, select

# NOTE(review): path hack retained for direct-script execution; prefer running
# as a package so the relative imports below resolve — confirm it is needed.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from ..models.task import Task, TaskCreate, TaskUpdate
from ..models.project import Project, ProjectCreate, ProjectUpdate
from ..models.user import User
from ..database import engine


class AddTaskParams(BaseModel):
    """Parameters for the add_task tool."""
    user_id: str = Field(description="User ID to create task for")
    title: str = Field(description="Task title, 1-200 characters")
    description: Optional[str] = Field(default=None, description="Task description, optional, max 1000 chars")
    due_date: Optional[str] = Field(default=None, description="Due date in ISO format (YYYY-MM-DD)")
    project_name: Optional[str] = Field(default=None, description="Optional name of the project to associate this task with")


class AddTaskResult(BaseModel):
    task_id: int
    status: str
    title: str


class ListTasksParams(BaseModel):
    """Parameters for the list_tasks tool."""
    user_id: str = Field(description="User ID to list tasks for")
    status: Optional[str] = Field(default="all", description="Task status filter: 'all', 'pending', 'completed'")


class ListTasksResultItem(BaseModel):
    id: int
    title: str
    completed: bool
    created_at: str


class ListTasksResult(BaseModel):
    tasks: List[ListTasksResultItem]


class CompleteTaskParams(BaseModel):
    """Parameters for the complete_task tool."""
    user_id: str = Field(description="User ID of the task owner")
    task_id: int = Field(description="ID of the task to complete")


class CompleteTaskResult(BaseModel):
    task_id: int
    status: str
    title: str


class DeleteTaskParams(BaseModel):
    """Parameters for the delete_task tool."""
    user_id: str = Field(description="User ID of the task owner")
    task_id: int = Field(description="ID of the task to delete")


class DeleteTaskResult(BaseModel):
    task_id: int
    status: str
    title: str


class UpdateTaskParams(BaseModel):
    """Parameters for the update_task tool."""
    user_id: str = Field(description="User ID of the task owner")
    task_id: int = Field(description="ID of the task to update")
    title: Optional[str] = Field(default=None, description="New task title")
    description: Optional[str] = Field(default=None, description="New task description")


class UpdateTaskResult(BaseModel):
    task_id: int
    status: str
    title: str


class CreateProjectParams(BaseModel):
    """Parameters for the create_project tool."""
    user_id: str = Field(description="User ID to create project for")
    name: str = Field(description="Project name")
    description: Optional[str] = Field(default=None, description="Project description")
    color: Optional[str] = Field(default="#3b82f6", description="Hex color code")


class CreateProjectResult(BaseModel):
    project_id: str
    status: str
    name: str


class ListProjectsParams(BaseModel):
    """Parameters for the list_projects tool."""
    user_id: str = Field(description="User ID to list projects for")


class ProjectResultItem(BaseModel):
    id: str
    name: str
    description: Optional[str]
    color: Optional[str]


class ListProjectsResult(BaseModel):
    projects: List[ProjectResultItem]


class GetCalendarParams(BaseModel):
    """Parameters for the get_calendar tool."""
    user_id: str = Field(description="User ID")
    start_date: str = Field(description="Start date ISO string")
    end_date: str = Field(description="End date ISO string")


class CalendarItem(BaseModel):
    id: int
    title: str
    due_date: str
    completed: bool
    project_name: Optional[str]


class GetCalendarResult(BaseModel):
    items: List[CalendarItem]


def get_task_tools_for_gemin_api():
    """Return all task tools in Gemini/OpenAI function-calling format.

    ``user_id`` is stripped from every schema so the model never asks the
    user for it — the chat endpoint injects it server-side.
    """
    def get_schema_without_user_id(model):
        # model_json_schema() builds a fresh dict, so in-place edits are safe.
        schema = model.model_json_schema()
        if "properties" in schema and "user_id" in schema["properties"]:
            del schema["properties"]["user_id"]
        if "required" in schema and "user_id" in schema["required"]:
            schema["required"].remove("user_id")
        return schema

    tools = [
        {
            "type": "function",
            "function": {
                "name": "add_task",
                "description": "Create a new task for the user. Do not ask for user_id.",
                "parameters": get_schema_without_user_id(AddTaskParams),
            },
        },
        {
            "type": "function",
            "function": {
                "name": "list_tasks",
                "description": "Retrieve user's tasks. Do not ask for user_id.",
                "parameters": get_schema_without_user_id(ListTasksParams),
            },
        },
        {
            "type": "function",
            "function": {
                "name": "complete_task",
                "description": "Mark a task as complete. Do not ask for user_id.",
                "parameters": get_schema_without_user_id(CompleteTaskParams),
            },
        },
        {
            "type": "function",
            "function": {
                "name": "delete_task",
                "description": "Remove a task. Do not ask for user_id.",
                "parameters": get_schema_without_user_id(DeleteTaskParams),
            },
        },
        {
            "type": "function",
            "function": {
                "name": "update_task",
                "description": "Modify task details. Do not ask for user_id.",
                "parameters": get_schema_without_user_id(UpdateTaskParams),
            },
        },
        {
            "type": "function",
            "function": {
                "name": "create_project",
                "description": "Create a new project. Projects can hold multiple tasks. Do not ask for user_id.",
                "parameters": get_schema_without_user_id(CreateProjectParams),
            },
        },
        {
            "type": "function",
            "function": {
                "name": "list_projects",
                "description": "List all projects for the user. Do not ask for user_id.",
                "parameters": get_schema_without_user_id(ListProjectsParams),
            },
        },
        {
            "type": "function",
            "function": {
                "name": "get_calendar",
                "description": "Get tasks and events for a specific date range (Calendar view). Do not ask for user_id.",
                "parameters": get_schema_without_user_id(GetCalendarParams),
            },
        },
    ]
    return tools


# Backward-compatible, correctly-spelled alias for the typo'd public name.
get_task_tools_for_gemini_api = get_task_tools_for_gemin_api


def get_task_tools():
    """Return task tools in MCP format (input_schema, user_id included).

    NOTE(review): only the five task tools are listed here, while the Gemini
    list above also exposes project/calendar tools — confirm the omission is
    intentional before registering more MCP handlers.
    """
    return [
        {
            "name": "add_task",
            "description": "Create a new task for the user",
            "input_schema": AddTaskParams.model_json_schema(),
        },
        {
            "name": "list_tasks",
            "description": "Retrieve user's tasks",
            "input_schema": ListTasksParams.model_json_schema(),
        },
        {
            "name": "complete_task",
            "description": "Mark a task as complete",
            "input_schema": CompleteTaskParams.model_json_schema(),
        },
        {
            "name": "delete_task",
            "description": "Remove a task",
            "input_schema": DeleteTaskParams.model_json_schema(),
        },
        {
            "name": "update_task",
            "description": "Modify task details",
            "input_schema": UpdateTaskParams.model_json_schema(),
        },
    ]
def _parse_user_uuid(user_id: str) -> uuid.UUID:
    """Validate and convert a user_id string to a UUID.

    Raises ValueError with the tool-facing message on malformed input.
    Extracted because every execute_* function repeated this block.
    """
    try:
        return uuid.UUID(user_id)
    except ValueError:
        raise ValueError(f"Invalid user_id format: {user_id}")


def _get_owned_task(db_session: Session, task_id: int, user_uuid: uuid.UUID, user_id: str) -> Task:
    """Fetch a task by id, enforcing ownership (user-data isolation).

    Raises ValueError when the task does not exist or belongs to another
    user — the same message in both cases, so ownership is not leaked.
    """
    task = db_session.exec(
        select(Task).where(Task.id == task_id, Task.user_id == user_uuid)
    ).first()
    if not task:
        raise ValueError(f"Task with id {task_id} not found for user {user_id}")
    return task


def execute_add_task(params: AddTaskParams) -> AddTaskResult:
    """Create a task for the user, optionally attached to a named project."""
    # FIX: the previous `try: ... except Exception as e: raise e` wrappers in
    # all execute_* functions were no-ops and have been removed; exceptions
    # propagate identically.
    user_uuid = _parse_user_uuid(params.user_id)

    with Session(engine) as db_session:
        # Verify the user exists before creating anything.
        user_exists = db_session.exec(
            select(User).where(User.id == user_uuid)
        ).first()
        if not user_exists:
            raise ValueError(f"User with id {params.user_id} not found")

        # Best-effort due-date parsing: full ISO first, then bare YYYY-MM-DD;
        # an unparseable value is silently ignored (task created undated).
        due_date_dt = None
        if params.due_date:
            try:
                due_date_dt = datetime.fromisoformat(params.due_date.replace('Z', '+00:00'))
            except ValueError:
                try:
                    due_date_dt = datetime.strptime(params.due_date, "%Y-%m-%d")
                except ValueError:
                    pass

        # Resolve the optional project by name, scoped to this user; an
        # unknown name silently yields an unassigned task.
        project_id = None
        if params.project_name:
            project = db_session.exec(
                select(Project).where(
                    Project.name == params.project_name,
                    Project.user_id == user_uuid,
                )
            ).first()
            if project:
                project_id = project.id

        task = Task(
            title=params.title,
            description=params.description,
            due_date=due_date_dt,
            user_id=user_uuid,
            project_id=project_id,
        )
        db_session.add(task)
        db_session.commit()
        db_session.refresh(task)

        return AddTaskResult(task_id=task.id, status="created", title=task.title)


def execute_list_tasks(params: ListTasksParams) -> ListTasksResult:
    """List the user's tasks, optionally filtered by completion status."""
    user_uuid = _parse_user_uuid(params.user_id)

    with Session(engine) as db_session:
        query = select(Task).where(Task.user_id == user_uuid)
        status = (params.status or "").lower()
        if status == "completed":
            query = query.where(Task.completed == True)
        elif status == "pending":
            query = query.where(Task.completed == False)

        tasks = db_session.exec(query).all()
        return ListTasksResult(tasks=[
            ListTasksResultItem(
                id=task.id,
                title=task.title,
                completed=task.completed,
                created_at=task.created_at.isoformat() if task.created_at else "",
            )
            for task in tasks
        ])


def execute_complete_task(params: CompleteTaskParams) -> CompleteTaskResult:
    """Mark one of the user's tasks as completed."""
    user_uuid = _parse_user_uuid(params.user_id)

    with Session(engine) as db_session:
        task = _get_owned_task(db_session, params.task_id, user_uuid, params.user_id)
        task.completed = True
        db_session.add(task)
        db_session.commit()
        db_session.refresh(task)
        return CompleteTaskResult(task_id=task.id, status="completed", title=task.title)


def execute_delete_task(params: DeleteTaskParams) -> DeleteTaskResult:
    """Delete one of the user's tasks."""
    user_uuid = _parse_user_uuid(params.user_id)

    with Session(engine) as db_session:
        task = _get_owned_task(db_session, params.task_id, user_uuid, params.user_id)
        result = DeleteTaskResult(task_id=task.id, status="deleted", title=task.title)
        db_session.delete(task)
        db_session.commit()
        return result


def execute_update_task(params: UpdateTaskParams) -> UpdateTaskResult:
    """Update the title and/or description of one of the user's tasks."""
    user_uuid = _parse_user_uuid(params.user_id)

    with Session(engine) as db_session:
        task = _get_owned_task(db_session, params.task_id, user_uuid, params.user_id)
        if params.title is not None:
            task.title = params.title
        if params.description is not None:
            task.description = params.description
        db_session.add(task)
        db_session.commit()
        db_session.refresh(task)
        return UpdateTaskResult(task_id=task.id, status="updated", title=task.title)


def execute_create_project(params: CreateProjectParams) -> CreateProjectResult:
    """Create a new project (task collection) for the user."""
    user_uuid = _parse_user_uuid(params.user_id)

    with Session(engine) as db_session:
        project = Project(
            name=params.name,
            description=params.description,
            color=params.color,
            user_id=user_uuid,
        )
        db_session.add(project)
        db_session.commit()
        db_session.refresh(project)
        return CreateProjectResult(project_id=str(project.id), status="created", name=project.name)


def execute_list_projects(params: ListProjectsParams) -> ListProjectsResult:
    """List all of the user's projects."""
    user_uuid = _parse_user_uuid(params.user_id)

    with Session(engine) as db_session:
        projects = db_session.exec(
            select(Project).where(Project.user_id == user_uuid)
        ).all()
        return ListProjectsResult(projects=[
            ProjectResultItem(
                id=str(p.id),
                name=p.name,
                description=p.description,
                color=p.color,
            )
            for p in projects
        ])


def execute_get_calendar(params: GetCalendarParams) -> GetCalendarResult:
    """Return the user's tasks with a due date inside [start_date, end_date]."""
    user_uuid = _parse_user_uuid(params.user_id)
    start_dt = datetime.fromisoformat(params.start_date.replace('Z', '+00:00'))
    end_dt = datetime.fromisoformat(params.end_date.replace('Z', '+00:00'))

    with Session(engine) as db_session:
        tasks = db_session.exec(
            select(Task).where(
                Task.user_id == user_uuid,
                Task.due_date >= start_dt,
                Task.due_date <= end_dt,
            )
        ).all()

        # NOTE(review): one project lookup per task (N+1); fine at current
        # volumes, batch via a single select if calendars grow large.
        items = []
        for task in tasks:
            p_name = None
            if task.project_id:
                p = db_session.exec(
                    select(Project).where(Project.id == task.project_id)
                ).first()
                if p:
                    p_name = p.name
            items.append(CalendarItem(
                id=task.id,
                title=task.title,
                due_date=task.due_date.isoformat() if task.due_date else "",
                completed=task.completed,
                project_name=p_name,
            ))

        return GetCalendarResult(items=items)


# --- src/middleware/auth.py ---
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlmodel import Session

from ..models.user import User
from ..utils.security import verify_user_id_from_token
from ..database import get_session_dep

# Security scheme for JWT bearer tokens.
security = HTTPBearer()


async def verify_jwt_token(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    session: Session = Depends(get_session_dep),
):
    """FastAPI dependency: validate the bearer JWT and return the user_id.

    Raises 401 when the token is invalid/expired or when the user it names
    no longer exists (the DB check prevents tokens from outliving accounts).
    """
    token = credentials.credentials
    user_id = verify_user_id_from_token(token)

    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token or expired token.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # NOTE(review): user_id here is whatever verify_user_id_from_token
    # returns; session.get must receive a type matching User's UUID primary
    # key — confirm the conversion happens in the security util.
    user = session.get(User, user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User no longer exists.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return user_id
ProjectRead, ProjectUpdate +from .conversation import Conversation, ConversationCreate, ConversationRead, ConversationUpdate +from .message import Message, MessageCreate, MessageRead, MessageUpdate +from .audit_log import AuditLog, AuditLogCreate, AuditLogRead + +__all__ = [ + "User", + "UserCreate", + "UserRead", + "Task", + "TaskCreate", + "TaskRead", + "TaskUpdate", + "Project", + "ProjectCreate", + "ProjectRead", + "ProjectUpdate", + "Conversation", + "ConversationCreate", + "ConversationRead", + "ConversationUpdate", + "Message", + "MessageCreate", + "MessageRead", + "MessageUpdate", + "AuditLog", + "AuditLogCreate", + "AuditLogRead", +] \ No newline at end of file diff --git a/src/models/audit_log.py b/src/models/audit_log.py new file mode 100644 index 0000000000000000000000000000000000000000..aa584f890af48b3b7399e7b51f1903f88e551032 --- /dev/null +++ b/src/models/audit_log.py @@ -0,0 +1,40 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime +from sqlalchemy import Column, DateTime, JSON +import uuid + + +class AuditLogBase(SQLModel): + event_id: str = Field(index=True) # UUID for deduplication + event_type: str = Field(max_length=50) # created|updated|completed|deleted + user_id: str # String user identifier + task_id: int # Reference to the affected task + event_data: dict = Field(sa_column=Column(JSON)) # JSONB field for event data + timestamp: datetime = Field(sa_column=Column(DateTime(timezone=True)), default=datetime.utcnow) + + +class AuditLog(AuditLogBase, table=True): + """ + Persistent record of all task events for a user. + Contains id, event_id, event_type, user_id, task_id, event_data (JSONB), and timestamp. 
+ """ + id: Optional[int] = Field(default=None, primary_key=True) + event_id: str = Field(index=True, unique=True) # Unique constraint for deduplication + event_type: str = Field(max_length=50) # created|updated|completed|deleted + user_id: str # String user identifier + task_id: int # Reference to the affected task + event_data: dict = Field(sa_column=Column(JSON)) # JSONB field for event data + timestamp: datetime = Field(sa_column=Column(DateTime(timezone=True)), default=datetime.utcnow) + + +class AuditLogCreate(AuditLogBase): + pass + + +class AuditLogRead(AuditLogBase): + id: int + timestamp: datetime + + class Config: + from_attributes = True \ No newline at end of file diff --git a/src/models/conversation.py b/src/models/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..470e83b2919cfb227078c69c3fcc1e70ddebcf22 --- /dev/null +++ b/src/models/conversation.py @@ -0,0 +1,40 @@ +from sqlmodel import SQLModel, Field, Relationship +from typing import Optional, List +from datetime import datetime +from sqlalchemy import Column, DateTime +import uuid +from .user import User # Import User model for relationship + + +class ConversationBase(SQLModel): + user_id: uuid.UUID = Field(foreign_key="user.id", index=True) + + +class Conversation(ConversationBase, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + user_id: uuid.UUID = Field(foreign_key="user.id", index=True) + created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.utcnow)) + updated_at: datetime = Field(sa_column=Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)) + + # Relationship to user + owner: Optional["User"] = Relationship(back_populates="conversations") + + # Relationship to messages + messages: List["Message"] = Relationship(back_populates="conversation") + + +class ConversationCreate(ConversationBase): + pass + + +class ConversationRead(ConversationBase): + id: int + created_at: datetime + updated_at: 
datetime + + class Config: + from_attributes = True + + +class ConversationUpdate(SQLModel): + pass \ No newline at end of file diff --git a/src/models/message.py b/src/models/message.py new file mode 100644 index 0000000000000000000000000000000000000000..f8b571be3d8f1c206e0f592bc0b889b1d57232d7 --- /dev/null +++ b/src/models/message.py @@ -0,0 +1,41 @@ +from sqlmodel import SQLModel, Field, Relationship +from typing import Optional +from datetime import datetime +from sqlalchemy import Column, DateTime, Enum as SAEnum +import uuid +from .conversation import Conversation # Import Conversation model for relationship + + +class MessageBase(SQLModel): + conversation_id: int = Field(foreign_key="conversation.id", index=True) + user_id: uuid.UUID + role: str = Field(sa_column=Column("role", SAEnum("user", "assistant", name="message_role"))) + content: str + + +class Message(MessageBase, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + conversation_id: int = Field(foreign_key="conversation.id", index=True) + user_id: uuid.UUID + role: str = Field(sa_column=Column("role", SAEnum("user", "assistant", name="message_role"))) + content: str + created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.utcnow)) + + # Relationship to conversation + conversation: Optional["Conversation"] = Relationship(back_populates="messages") + + +class MessageCreate(MessageBase): + pass + + +class MessageRead(MessageBase): + id: int + created_at: datetime + + class Config: + from_attributes = True + + +class MessageUpdate(SQLModel): + content: Optional[str] = None \ No newline at end of file diff --git a/src/models/project.py b/src/models/project.py new file mode 100644 index 0000000000000000000000000000000000000000..351e3ee501b1a9d133c9dfa0e7e9ce3cff512327 --- /dev/null +++ b/src/models/project.py @@ -0,0 +1,49 @@ +from sqlmodel import SQLModel, Field, Relationship +from typing import Optional, List +import uuid +from datetime import datetime +from 
sqlalchemy import Column, DateTime + + +class ProjectBase(SQLModel): + name: str = Field(min_length=1, max_length=200) + description: Optional[str] = Field(default=None, max_length=1000) + color: Optional[str] = Field(default="#3b82f6", max_length=7) # Hex color code + + +class Project(ProjectBase, table=True): + id: Optional[uuid.UUID] = Field(default_factory=uuid.uuid4, primary_key=True) + user_id: uuid.UUID = Field(foreign_key="user.id", index=True) + name: str = Field(min_length=1, max_length=200) + description: Optional[str] = Field(default=None, max_length=1000) + color: Optional[str] = Field(default="#3b82f6", max_length=7) + created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.utcnow)) + updated_at: datetime = Field(sa_column=Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)) + deadline: Optional[datetime] = None + + # Relationship to user + owner: Optional["User"] = Relationship(back_populates="projects") + + # Relationship to tasks + tasks: List["Task"] = Relationship(back_populates="project") + + +class ProjectCreate(ProjectBase): + pass + + +class ProjectRead(ProjectBase): + id: uuid.UUID + user_id: uuid.UUID + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + +class ProjectUpdate(SQLModel): + name: Optional[str] = Field(default=None, min_length=1, max_length=200) + description: Optional[str] = Field(default=None, max_length=1000) + color: Optional[str] = Field(default=None, max_length=7) + deadline: Optional[datetime] = None \ No newline at end of file diff --git a/src/models/task.py b/src/models/task.py new file mode 100644 index 0000000000000000000000000000000000000000..7bc5e476ec5f01aa8354c1a4778438c43c36b333 --- /dev/null +++ b/src/models/task.py @@ -0,0 +1,57 @@ +from sqlmodel import SQLModel, Field, Relationship +from typing import Optional +from datetime import datetime +import uuid +from sqlalchemy import Column, DateTime + + +class TaskBase(SQLModel): + title: str = 
Field(min_length=1, max_length=200) + description: Optional[str] = Field(default=None, max_length=1000) + completed: bool = Field(default=False) + due_date: Optional[datetime] = None + is_ai_generated: bool = Field(default=False) + + +class Task(TaskBase, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + user_id: uuid.UUID = Field(foreign_key="user.id", index=True) + project_id: Optional[uuid.UUID] = Field(default=None, foreign_key="project.id", index=True) + title: str = Field(min_length=1, max_length=200) + description: Optional[str] = Field(default=None, max_length=1000) + completed: bool = Field(default=False) + due_date: Optional[datetime] = None + created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.utcnow)) + updated_at: datetime = Field(sa_column=Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)) + + # Relationship to user + owner: Optional["User"] = Relationship(back_populates="tasks") + + # Relationship to project + project: Optional["Project"] = Relationship(back_populates="tasks") + + +class TaskCreate(TaskBase): + project_id: Optional[uuid.UUID] = None + is_ai_generated: bool = False + + +class TaskRead(TaskBase): + id: int + user_id: uuid.UUID + project_id: Optional[uuid.UUID] = None + created_at: datetime + updated_at: datetime + is_ai_generated: bool = False + + class Config: + from_attributes = True + + +class TaskUpdate(SQLModel): + title: Optional[str] = Field(default=None, min_length=1, max_length=200) + description: Optional[str] = Field(default=None, max_length=1000) + completed: Optional[bool] = None + project_id: Optional[uuid.UUID] = None + due_date: Optional[datetime] = None + is_ai_generated: Optional[bool] = None \ No newline at end of file diff --git a/src/models/user.py b/src/models/user.py new file mode 100644 index 0000000000000000000000000000000000000000..4027ff9004bfadf1d624cd4df82ee834c46793cd --- /dev/null +++ b/src/models/user.py @@ -0,0 +1,36 @@ +from sqlmodel 
import SQLModel, Field, Relationship +from typing import Optional, List +import uuid +from datetime import datetime +from sqlalchemy import Column, DateTime + + +class UserBase(SQLModel): + email: str = Field(unique=True, index=True, max_length=255) + + +class User(UserBase, table=True): + id: Optional[uuid.UUID] = Field(default_factory=uuid.uuid4, primary_key=True) + email: str = Field(unique=True, index=True, max_length=255) + password_hash: str = Field(max_length=255) + created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.utcnow)) + updated_at: datetime = Field(sa_column=Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)) + + # Relationship to tasks + tasks: List["Task"] = Relationship(back_populates="owner") + + # Relationship to projects + projects: List["Project"] = Relationship(back_populates="owner") + + # Relationship to conversations + conversations: List["Conversation"] = Relationship(back_populates="owner") + + +class UserCreate(UserBase): + password: str + + +class UserRead(UserBase): + id: uuid.UUID + created_at: datetime + updated_at: datetime \ No newline at end of file diff --git a/src/routers/__init__.py b/src/routers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1863edbf520fb180bc3c0a61458ba574a5928b4a --- /dev/null +++ b/src/routers/__init__.py @@ -0,0 +1,3 @@ +from . 
import auth, tasks, projects, chat, audit + +__all__ = ["auth", "tasks", "projects", "chat", "audit"] \ No newline at end of file diff --git a/src/routers/audit.py b/src/routers/audit.py new file mode 100644 index 0000000000000000000000000000000000000000..4e3906a643735ae91fdd7f3df0522f894a775836 --- /dev/null +++ b/src/routers/audit.py @@ -0,0 +1,125 @@ +from fastapi import APIRouter, HTTPException, Depends, status, Body +from sqlmodel import Session, select +from typing import List, Dict, Any +from uuid import UUID +import logging + +from ..models.audit_log import AuditLog, AuditLogCreate +from ..models.user import User +from ..database import get_session_dep +from ..utils.deps import get_current_user + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/audit", tags=["audit"]) + + +@router.post("/events") +async def receive_audit_event( + event: Dict[str, Any] = Body(...), + session: Session = Depends(get_session_dep) +): + """ + Receives audit events from Dapr Pub/Sub (Kafka) and saves them to the database. + This endpoint is called by the Dapr sidecar when events are published. 
+ """ + try: + logger.info(f"Received audit event: {event}") + + # Extract event data + event_id = event.get("event_id") + event_type = event.get("event_type") + user_id = event.get("user_id") + task_id = event.get("task_id") + task_data = event.get("task_data", {}) + timestamp = event.get("timestamp") + + # Validate required fields + if not all([event_id, event_type, user_id, task_id]): + logger.warning(f"Missing required fields in event: {event}") + return {"status": "error", "message": "Missing required fields"} + + # Check if event already exists (deduplication) + existing = session.exec( + select(AuditLog).where(AuditLog.event_id == event_id) + ).first() + + if existing: + logger.info(f"Event {event_id} already exists, skipping") + return {"status": "skipped", "message": "Event already exists"} + + # Create audit log entry + audit_log = AuditLog( + event_id=event_id, + event_type=event_type, + user_id=user_id, + task_id=task_id, + event_data={ + "title": task_data.get("title", ""), + "description": task_data.get("description", ""), + "completed": task_data.get("completed", False) + } + ) + + session.add(audit_log) + session.commit() + session.refresh(audit_log) + + logger.info(f"Audit event {event_id} saved successfully") + return {"status": "success", "message": "Event saved", "id": audit_log.id} + + except Exception as e: + logger.error(f"Error saving audit event: {e}", exc_info=True) + session.rollback() + return {"status": "error", "message": str(e)} + + +@router.get("/events/{user_id}", response_model=dict) +async def get_user_audit_events( + user_id: UUID, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep), + offset: int = 0, + limit: int = 50 +): + """Get audit events for a specific user.""" + + # Verify that the user_id matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) + + # Query audit logs for the 
user + query = select(AuditLog).where(AuditLog.user_id == str(user_id)).order_by(AuditLog.timestamp.desc()) + + # Get total count + total_query = select(AuditLog).where(AuditLog.user_id == str(user_id)) + total_count = len(session.exec(total_query).all()) + + # Apply pagination + audit_logs = session.exec(query.offset(offset).limit(limit)).all() + + # Convert to dict + events = [ + { + "id": log.id, + "event_id": log.event_id, + "event_type": log.event_type, + "user_id": log.user_id, + "task_id": log.task_id, + "event_data": log.event_data, + "timestamp": log.timestamp.isoformat() if log.timestamp else None + } + for log in audit_logs + ] + + return { + "events": events, + "total": total_count, + "offset": offset, + "limit": limit + } diff --git a/src/routers/auth.py b/src/routers/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0c2acfea3b76cb505c54ea6046a0348280fa57 --- /dev/null +++ b/src/routers/auth.py @@ -0,0 +1,189 @@ +from fastapi import APIRouter, HTTPException, status, Depends, Response, Request +from sqlmodel import Session, select +from typing import Annotated +from datetime import datetime, timedelta +from uuid import uuid4 +import secrets + +from ..models.user import User, UserCreate, UserRead +from ..schemas.auth import RegisterRequest, RegisterResponse, LoginRequest, LoginResponse, ForgotPasswordRequest, ResetPasswordRequest +from ..utils.security import hash_password, create_access_token, verify_password +from ..utils.deps import get_current_user +from ..database import get_session_dep +from ..config import settings + + +router = APIRouter(prefix="/api/auth", tags=["auth"]) + + +@router.post("/register", response_model=RegisterResponse, status_code=status.HTTP_201_CREATED) +def register(user_data: RegisterRequest, response: Response, session: Session = Depends(get_session_dep)): + """Register a new user with email and password.""" + + # Check if user already exists + existing_user = 
session.exec(select(User).where(User.email == user_data.email)).first() + if existing_user: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="An account with this email already exists" + ) + + # Validate password length + if len(user_data.password) < 8: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Password must be at least 8 characters" + ) + + # Hash the password + password_hash = hash_password(user_data.password) + + # Create new user + user = User( + email=user_data.email, + password_hash=password_hash + ) + + session.add(user) + session.commit() + session.refresh(user) + + # Create access token + access_token = create_access_token(data={"sub": str(user.id)}) + + # Set the token as an httpOnly cookie + response.set_cookie( + key="access_token", + value=access_token, + httponly=True, + secure=settings.JWT_COOKIE_SECURE, # True in production, False in development + samesite=settings.JWT_COOKIE_SAMESITE, + max_age=settings.ACCESS_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, # Convert days to seconds + path="/" + ) + + # Return response + return RegisterResponse( + id=user.id, + email=user.email, + message="Account created successfully" + ) + + +@router.post("/login", response_model=LoginResponse) +def login(login_data: LoginRequest, response: Response, session: Session = Depends(get_session_dep)): + """Authenticate user with email and password, return JWT token.""" + + # Find user by email + user = session.exec(select(User).where(User.email == login_data.email)).first() + + if not user or not verify_password(login_data.password, user.password_hash): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid email or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Create access token + access_token = create_access_token(data={"sub": str(user.id)}) + + # Set the token as an httpOnly cookie + response.set_cookie( + key="access_token", + value=access_token, + httponly=True, + 
secure=settings.JWT_COOKIE_SECURE, # True in production, False in development + samesite=settings.JWT_COOKIE_SAMESITE, + max_age=settings.ACCESS_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, # Convert days to seconds + path="/" + ) + + # Debug: Print the cookie being set + print(f"Setting cookie: access_token={access_token}") + print(f"Cookie attributes: httponly={True}, secure={settings.JWT_COOKIE_SECURE}, samesite={settings.JWT_COOKIE_SAMESITE}, max_age={settings.ACCESS_TOKEN_EXPIRE_DAYS * 24 * 60 * 60}") + + # Return response + return LoginResponse( + access_token=access_token, + token_type="bearer", + user=RegisterResponse( + id=user.id, + email=user.email, + message="Login successful" + ) + ) + + +@router.post("/logout") +def logout(request: Request, response: Response, current_user: User = Depends(get_current_user)): + """Logout user by clearing the access token cookie.""" + # Clear the access_token cookie + response.set_cookie( + key="access_token", + value="", + httponly=True, + secure=settings.JWT_COOKIE_SECURE, + samesite=settings.JWT_COOKIE_SAMESITE, + max_age=0, # Expire immediately + path="/" + ) + + return {"message": "Logged out successfully"} + + +@router.get("/me", response_model=RegisterResponse) +def get_current_user_profile(request: Request, current_user: User = Depends(get_current_user)): + """Get the current authenticated user's profile.""" + # Debug: Print the cookies received + print(f"Received cookies: {request.cookies}") + print(f"Access token cookie: {request.cookies.get('access_token')}") + + return RegisterResponse( + id=current_user.id, + email=current_user.email, + message="User profile retrieved successfully" + ) + + +@router.post("/forgot-password") +def forgot_password(forgot_data: ForgotPasswordRequest, session: Session = Depends(get_session_dep)): + """Initiate password reset process by verifying email exists.""" + # Check if user exists + user = session.exec(select(User).where(User.email == forgot_data.email)).first() + + if not user: + # 
For security reasons, we don't reveal if the email exists or not + return {"message": "If the email exists, a reset link would be sent"} + + # In a real implementation, we would send an email here + # But as per requirements, we're just simulating the process + return {"message": "If the email exists, a reset link would be sent"} + + +@router.post("/reset-password") +def reset_password(reset_data: ResetPasswordRequest, session: Session = Depends(get_session_dep)): + """Reset user password after verification.""" + # Check if user exists + user = session.exec(select(User).where(User.email == reset_data.email)).first() + + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) + + # Validate password length + if len(reset_data.new_password) < 8: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Password must be at least 8 characters" + ) + + # Hash the new password + user.password_hash = hash_password(reset_data.new_password) + + # Update the user + session.add(user) + session.commit() + + return {"message": "Password reset successfully"} \ No newline at end of file diff --git a/src/routers/chat.py b/src/routers/chat.py new file mode 100644 index 0000000000000000000000000000000000000000..3fde30daca141d6987ecfc218647e7b3a16a8e82 --- /dev/null +++ b/src/routers/chat.py @@ -0,0 +1,203 @@ +from fastapi import APIRouter, HTTPException, status, Depends +from sqlmodel import Session +from typing import Optional +from uuid import UUID +from pydantic import BaseModel +import json +import logging + +from ..models.user import User +from ..models.conversation import Conversation +from ..models.message import Message +from ..database import get_session_dep +from ..utils.deps import get_current_user +from ..services.conversation_service import ConversationService +from ..agent_config import todo_agent_config +from ..mcp_server import get_mcp_tools_for_gemin_api +from ..mcp_tools.task_tools import ( + 
execute_add_task, + execute_list_tasks, + execute_complete_task, + execute_delete_task, + execute_update_task, + execute_create_project, + execute_list_projects, + execute_get_calendar, + AddTaskParams, + ListTasksParams, + CompleteTaskParams, + DeleteTaskParams, + UpdateTaskParams, + CreateProjectParams, + ListProjectsParams, + GetCalendarParams +) + +router = APIRouter(prefix="/api/{user_id}/chat", tags=["chat"]) + +logger = logging.getLogger(__name__) + +class ChatRequest(BaseModel): + conversation_id: Optional[int] = None + message: str + +class ChatResponse(BaseModel): + conversation_id: int + response: str + tool_calls: list = [] + +@router.post("/", response_model=ChatResponse) +def chat( + user_id: UUID, + chat_request: ChatRequest, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """ + Handle chat requests from users using AI assistant with tool calling. + """ + logger.info(f"Chat endpoint called with user_id: {user_id}, current_user.id: {current_user.id}") + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + logger.warning(f"User ID mismatch: path user_id={user_id}, auth user_id={current_user.id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Access denied" + ) + + # Get or create conversation + conversation_id = chat_request.conversation_id + if conversation_id is None: + conversation = Conversation(user_id=user_id) + session.add(conversation) + session.commit() + session.refresh(conversation) + conversation_id = conversation.id + else: + conversation = session.get(Conversation, conversation_id) + if not conversation or conversation.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Conversation not found" + ) + + # Store user message + user_message = Message( + conversation_id=conversation_id, + user_id=user_id, + role="user", + content=chat_request.message + ) + 
session.add(user_message) + session.commit() + + # Get conversation history (last 10 messages for context) + conversation_history = ConversationService.get_messages( + conversation_id=conversation_id, + user_id=user_id, + db_session=session, + limit=10 + ) + + history_for_agent = [] + for msg in conversation_history: + history_for_agent.append({ + "role": msg.role, + "content": msg.content + }) + + agent_config = todo_agent_config + tools = get_mcp_tools_for_gemin_api() + + messages = [ + {"role": "system", "content": agent_config["instructions"]}, + *history_for_agent, + {"role": "user", "content": chat_request.message} + ] + + try: + # Call the AI agent with tools + response = agent_config["client"].chat.completions.create( + model=agent_config["model"], + messages=messages, + tools=tools, + tool_choice="auto" + ) + + response_message = response.choices[0].message + tool_calls = response_message.tool_calls + + # If there are tool calls, execute them + if tool_calls: + # Add assistant's tool call message to history + messages.append(response_message) + + for tool_call in tool_calls: + function_name = tool_call.function.name + function_args = json.loads(tool_call.function.arguments) + + # Force the user_id to be the current user's ID for security + function_args["user_id"] = str(user_id) + + logger.info(f"Executing tool: {function_name} with args: {function_args}") + + result = None + try: + if function_name == "add_task": + result = execute_add_task(AddTaskParams(**function_args)) + elif function_name == "list_tasks": + result = execute_list_tasks(ListTasksParams(**function_args)) + elif function_name == "complete_task": + result = execute_complete_task(CompleteTaskParams(**function_args)) + elif function_name == "delete_task": + result = execute_delete_task(DeleteTaskParams(**function_args)) + elif function_name == "update_task": + result = execute_update_task(UpdateTaskParams(**function_args)) + elif function_name == "create_project": + result = 
execute_create_project(CreateProjectParams(**function_args)) + elif function_name == "list_projects": + result = execute_list_projects(ListProjectsParams(**function_args)) + elif function_name == "get_calendar": + result = execute_get_calendar(GetCalendarParams(**function_args)) + + tool_result_content = json.dumps(result.dict() if result else {"error": "Unknown tool"}) + except Exception as e: + logger.error(f"Error executing tool {function_name}: {str(e)}") + tool_result_content = json.dumps({"error": str(e)}) + + messages.append({ + "tool_call_id": tool_call.id, + "role": "tool", + "name": function_name, + "content": tool_result_content, + }) + + # Get final response from AI after tool results + second_response = agent_config["client"].chat.completions.create( + model=agent_config["model"], + messages=messages, + ) + ai_response = second_response.choices[0].message.content + else: + ai_response = response_message.content + + except Exception as e: + logger.error(f"Error in AI processing: {str(e)}") + ai_response = f"I encountered an error processing your request. Please try again later. 
(Error: {str(e)})" + + # Store assistant response + assistant_message = Message( + conversation_id=conversation_id, + user_id=user_id, + role="assistant", + content=ai_response + ) + session.add(assistant_message) + session.commit() + + return ChatResponse( + conversation_id=conversation_id, + response=ai_response, + tool_calls=[] # We already handled them + ) \ No newline at end of file diff --git a/src/routers/projects.py b/src/routers/projects.py new file mode 100644 index 0000000000000000000000000000000000000000..cf49633350b57428836796051622f3821d84fc15 --- /dev/null +++ b/src/routers/projects.py @@ -0,0 +1,259 @@ +from fastapi import APIRouter, HTTPException, status, Depends +from sqlmodel import Session, select, and_, func +from typing import List +from uuid import UUID +from datetime import datetime + +from ..models.user import User +from ..models.project import Project, ProjectCreate, ProjectUpdate, ProjectRead +from ..models.task import Task +from ..database import get_session_dep +from ..utils.deps import get_current_user + + +router = APIRouter(prefix="/api/{user_id}/projects", tags=["projects"]) + + +@router.get("/", response_model=List[ProjectRead]) +def list_projects( + user_id: UUID, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """List all projects for the authenticated user.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found" + ) + + # Build the query with user_id filter + query = select(Project).where(Project.user_id == user_id) + + # Apply ordering (newest first) + query = query.order_by(Project.created_at.desc()) + + projects = session.exec(query).all() + return projects + + +@router.post("/", response_model=ProjectRead, status_code=status.HTTP_201_CREATED) +def create_project( + *, + user_id: UUID, + project_data: ProjectCreate, + 
current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Create a new project for the authenticated user.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) + + # Create the project + project = Project( + name=project_data.name, + description=project_data.description, + color=project_data.color, + user_id=user_id + ) + + session.add(project) + session.commit() + session.refresh(project) + + return project + + +@router.get("/{project_id}", response_model=ProjectRead) +def get_project( + *, + user_id: UUID, + project_id: UUID, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Get a specific project by ID for the authenticated user.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found" + ) + + # Fetch the project + project = session.get(Project, project_id) + + # Check if project exists and belongs to the user + if not project or project.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found" + ) + + return project + + +@router.put("/{project_id}", response_model=ProjectRead) +def update_project( + *, + user_id: UUID, + project_id: UUID, + project_data: ProjectUpdate, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Update an existing project for the authenticated user.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found" + ) + + # Fetch the project + project = session.get(Project, project_id) + + # Check if project 
exists and belongs to the user + if not project or project.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found" + ) + + # Update the project + project_data_dict = project_data.dict(exclude_unset=True) + for key, value in project_data_dict.items(): + setattr(project, key, value) + + session.add(project) + session.commit() + session.refresh(project) + + return project + + +@router.delete("/{project_id}") +def delete_project( + *, + user_id: UUID, + project_id: UUID, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Delete a project for the authenticated user.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found" + ) + + # Fetch the project + project = session.get(Project, project_id) + + # Check if project exists and belongs to the user + if not project or project.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found" + ) + + # Delete the project + session.delete(project) + session.commit() + + return {"message": "Project deleted successfully"} + + +@router.get("/{project_id}/tasks", response_model=List[Task]) +def list_project_tasks( + *, + user_id: UUID, + project_id: UUID, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """List all tasks for a specific project.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found" + ) + + # Fetch the project + project = session.get(Project, project_id) + + # Check if project exists and belongs to the user + if not project or project.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + 
detail="Project not found" + ) + + # Build the query with project_id filter + query = select(Task).where(Task.project_id == project_id) + + # Apply ordering (newest first) + query = query.order_by(Task.created_at.desc()) + + tasks = session.exec(query).all() + return tasks + + +@router.get("/{project_id}/progress") +def get_project_progress( + *, + user_id: UUID, + project_id: UUID, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Get progress statistics for a specific project.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found" + ) + + # Fetch the project + project = session.get(Project, project_id) + + # Check if project exists and belongs to the user + if not project or project.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found" + ) + + # Get task counts + total_tasks_query = select(func.count()).where(Task.project_id == project_id) + completed_tasks_query = select(func.count()).where(and_(Task.project_id == project_id, Task.completed == True)) + + total_tasks = session.exec(total_tasks_query).first() + completed_tasks = session.exec(completed_tasks_query).first() + + # Calculate progress + progress = 0 + if total_tasks > 0: + progress = round((completed_tasks / total_tasks) * 100, 2) + + return { + "total_tasks": total_tasks, + "completed_tasks": completed_tasks, + "pending_tasks": total_tasks - completed_tasks, + "progress": progress + } \ No newline at end of file diff --git a/src/routers/tasks.py b/src/routers/tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..db87413f31a683e792c65463a0755bfba74fda00 --- /dev/null +++ b/src/routers/tasks.py @@ -0,0 +1,637 @@ +from fastapi import APIRouter, HTTPException, status, Depends +from sqlmodel import Session, select, and_, func 
+from typing import List +from uuid import UUID +from datetime import datetime, timedelta, date +import logging +import uuid as uuid_lib + +from ..models.user import User +from ..models.task import Task, TaskCreate, TaskUpdate, TaskRead +from ..models.audit_log import AuditLog +from ..schemas.task import TaskListResponse +from ..database import get_session_dep +from ..utils.deps import get_current_user +from ..events import publish_created_event, publish_updated_event, publish_deleted_event, publish_completed_event + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def save_audit_event( + session: Session, + event_type: str, + task: Task, + user_id: UUID +): + """ + Directly save an audit event to the database. + This bypasses Kafka/Dapr for now and provides immediate persistence. + """ + try: + event_id = str(uuid_lib.uuid4()) + audit_log = AuditLog( + event_id=event_id, + event_type=event_type, + user_id=str(user_id), + task_id=task.id, + event_data={ + "title": task.title, + "description": task.description or "", + "completed": task.completed + } + ) + session.add(audit_log) + session.flush() # Flush to database without committing (parent transaction handles commit) + logger.info(f"Audit event {event_type} saved for task {task.id}") + except Exception as e: + logger.error(f"Failed to save audit event: {e}") + # Don't raise - continue execution even if audit save fails + + +router = APIRouter(prefix="/api/{user_id}/tasks", tags=["tasks"]) + + +@router.get("/stats") +def get_task_stats( + user_id: UUID, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Get advanced task statistics, streaks, and achievements.""" + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) + + tasks = session.exec(select(Task).where(Task.user_id == user_id)).all() + + total = len(tasks) + completed_tasks = [t 
for t in tasks if t.completed] + completed_count = len(completed_tasks) + pending_count = total - completed_count + completion_rate = round((completed_count / total * 100), 1) if total > 0 else 0 + + # Streak calculation + # Group completed tasks by day (using updated_at as completion time for now) + completed_dates = sorted(list(set([t.updated_at.date() for t in completed_tasks])), reverse=True) + + streak = 0 + if completed_dates: + today = datetime.utcnow().date() + yesterday = today - timedelta(days=1) + + # Check if the streak is still active (completed something today or yesterday) + if completed_dates[0] == today or completed_dates[0] == yesterday: + # We count the current active streak + streak = 1 + for i in range(len(completed_dates) - 1): + if completed_dates[i] - timedelta(days=1) == completed_dates[i+1]: + streak += 1 + else: + break + + # Achievements logic + achievements = [ + { + "id": "first_task", + "title": "First Step", + "description": "Complete your first task", + "unlocked": completed_count >= 1, + "icon": "Star", + "progress": 100 if completed_count >= 1 else 0 + }, + { + "id": "five_tasks", + "title": "High Five", + "description": "Complete 5 tasks", + "unlocked": completed_count >= 5, + "icon": "Zap", + "progress": min(100, int(completed_count / 5 * 100)) + }, + { + "id": "ten_tasks", + "title": "Task Master", + "description": "Complete 10 tasks", + "unlocked": completed_count >= 10, + "icon": "Trophy", + "progress": min(100, int(completed_count / 10 * 100)) + }, + { + "id": "streak_3", + "title": "Consistent", + "description": "3-day completion streak", + "unlocked": streak >= 3, + "icon": "Flame", + "progress": min(100, int(streak / 3 * 100)) + }, + { + "id": "streak_7", + "title": "Unstoppable", + "description": "7-day completion streak", + "unlocked": streak >= 7, + "icon": "Award", + "progress": min(100, int(streak / 7 * 100)) + } + ] + + # Productivity chart data (last 7 days) + chart_data = [] + for i in range(6, -1, -1): + day = 
(datetime.utcnow() - timedelta(days=i)).date() + count = len([t for t in completed_tasks if t.updated_at.date() == day]) + chart_data.append({ + "date": day.strftime("%a"), + "count": count, + "isToday": i == 0 + }) + + return { + "total": total, + "completed": completed_count, + "pending": pending_count, + "completionRate": completion_rate, + "streak": streak, + "achievements": achievements, + "chartData": chart_data + } + + +@router.get("/", response_model=TaskListResponse) +def list_tasks( + user_id: UUID, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep), + completed: bool = None, + offset: int = 0, + limit: int = 50 +): + """List all tasks for the authenticated user with optional filtering.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + # Build the query with user_id filter + query = select(Task).where(Task.user_id == user_id) + + # Apply completed filter if specified + if completed is not None: + query = query.where(Task.completed == completed) + + # Apply ordering (newest first) + query = query.order_by(Task.created_at.desc()) + + # Apply pagination + query = query.offset(offset).limit(limit) + + tasks = session.exec(query).all() + + # Get total count for pagination info + total_query = select(func.count()).select_from(Task).where(Task.user_id == user_id) + if completed is not None: + total_query = total_query.where(Task.completed == completed) + total = session.exec(total_query).one() + + # Convert to response format + task_responses = [] + for task in tasks: + task_dict = { + "id": task.id, + "user_id": str(task.user_id), + "title": task.title, + "description": task.description, + "completed": task.completed, + "due_date": task.due_date.isoformat() if task.due_date else None, + "project_id": str(task.project_id) if task.project_id else None, + 
"created_at": task.created_at.isoformat(), + "updated_at": task.updated_at.isoformat() + } + task_responses.append(task_dict) + + return TaskListResponse( + tasks=task_responses, + total=total, + offset=offset, + limit=limit + ) + + +@router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED) +async def create_task( + user_id: UUID, + task_data: TaskCreate, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Create a new task for the authenticated user.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) + + # Validate title length + if len(task_data.title) < 1 or len(task_data.title) > 200: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Title must be between 1 and 200 characters" + ) + + # Validate description length if provided + if task_data.description and len(task_data.description) > 1000: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Description must be 1000 characters or less" + ) + + # Create new task + task = Task( + title=task_data.title, + description=task_data.description, + completed=task_data.completed, + due_date=task_data.due_date, + project_id=task_data.project_id, + user_id=user_id + ) + + session.add(task) + session.commit() + session.refresh(task) + + # Publish created event + try: + await publish_created_event(task) + logger.info(f"Published created event for task {task.id}") + except Exception as e: + logger.error(f"Failed to publish created event for task {task.id}: {e}") + # Continue execution even if event publishing fails + + # Save audit event to database + save_audit_event(session, "created", task, user_id) + + return TaskRead( + id=task.id, + user_id=task.user_id, + title=task.title, + description=task.description, + completed=task.completed, + 
due_date=task.due_date, + project_id=task.project_id, + created_at=task.created_at, + updated_at=task.updated_at + ) + + +@router.get("/{task_id}", response_model=TaskRead) +def get_task( + user_id: UUID, + task_id: int, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Get a specific task by ID for the authenticated user.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + # Get the task + task = session.get(Task, task_id) + + # Verify the task exists and belongs to the user + if not task or task.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + return TaskRead( + id=task.id, + user_id=task.user_id, + title=task.title, + description=task.description, + completed=task.completed, + due_date=task.due_date, + project_id=task.project_id, + created_at=task.created_at, + updated_at=task.updated_at + ) + + +@router.put("/{task_id}", response_model=TaskRead) +async def update_task( + user_id: UUID, + task_id: int, + task_data: TaskUpdate, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Update an existing task for the authenticated user.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + # Get the task + task = session.get(Task, task_id) + + # Verify the task exists and belongs to the user + if not task or task.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + # Store original values for the event + original_completed = task.completed + + # Update fields if provided + if task_data.title is not None: + if len(task_data.title) < 
1 or len(task_data.title) > 200: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Title must be between 1 and 200 characters" + ) + task.title = task_data.title + + if task_data.description is not None: + if len(task_data.description) > 1000: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Description must be 1000 characters or less" + ) + task.description = task_data.description + + if task_data.completed is not None: + task.completed = task_data.completed + + if task_data.due_date is not None: + task.due_date = task_data.due_date + + if task_data.project_id is not None: + task.project_id = task_data.project_id + + # Update the timestamp + task.updated_at = datetime.utcnow() + + session.add(task) + session.commit() + session.refresh(task) + + # Publish updated event + try: + await publish_updated_event(task) + logger.info(f"Published updated event for task {task.id}") + except Exception as e: + logger.error(f"Failed to publish updated event for task {task.id}: {e}") + # Continue execution even if event publishing fails + + # Save audit event for update + save_audit_event(session, "updated", task, user_id) + + # If the task was marked as completed, publish a completed event + if original_completed != task.completed and task.completed: + try: + await publish_completed_event(task) + logger.info(f"Published completed event for task {task.id}") + except Exception as e: + logger.error(f"Failed to publish completed event for task {task.id}: {e}") + + # Save audit event for completion + save_audit_event(session, "completed", task, user_id) + + return TaskRead( + id=task.id, + user_id=task.user_id, + title=task.title, + description=task.description, + completed=task.completed, + due_date=task.due_date, + project_id=task.project_id, + created_at=task.created_at, + updated_at=task.updated_at + ) + + +@router.patch("/{task_id}", response_model=TaskRead) +async def patch_task( + user_id: UUID, + task_id: int, + task_data: 
TaskUpdate, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Partially update an existing task for the authenticated user.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + # Get the task + task = session.get(Task, task_id) + + # Verify the task exists and belongs to the user + if not task or task.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + # Store original values for the event + original_completed = task.completed + + # Update fields if provided + if task_data.title is not None: + if len(task_data.title) < 1 or len(task_data.title) > 200: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Title must be between 1 and 200 characters" + ) + task.title = task_data.title + + if task_data.description is not None: + if len(task_data.description) > 1000: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Description must be 1000 characters or less" + ) + task.description = task_data.description + + if task_data.completed is not None: + task.completed = task_data.completed + + if task_data.due_date is not None: + task.due_date = task_data.due_date + + if task_data.project_id is not None: + task.project_id = task_data.project_id + + # Update the timestamp + task.updated_at = datetime.utcnow() + + session.add(task) + session.commit() + session.refresh(task) + + # Publish updated event + try: + await publish_updated_event(task) + logger.info(f"Published updated event for task {task.id}") + except Exception as e: + logger.error(f"Failed to publish updated event for task {task.id}: {e}") + + # Save audit event for update + save_audit_event(session, "updated", task, user_id) + + # If the task was marked as completed, publish a completed event + if 
original_completed != task.completed and task.completed: + try: + await publish_completed_event(task) + logger.info(f"Published completed event for task {task.id}") + except Exception as e: + logger.error(f"Failed to publish completed event for task {task.id}: {e}") + + # Save audit event for completion + save_audit_event(session, "completed", task, user_id) + + return TaskRead( + id=task.id, + user_id=task.user_id, + title=task.title, + description=task.description, + completed=task.completed, + due_date=task.due_date, + project_id=task.project_id, + created_at=task.created_at, + updated_at=task.updated_at + ) + + +@router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_task( + user_id: UUID, + task_id: int, + current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Delete a task for the authenticated user.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + # Get the task + task = session.get(Task, task_id) + + # Verify the task exists and belongs to the user + if not task or task.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + # Publish deleted event before deleting the task + try: + await publish_deleted_event(task) + logger.info(f"Published deleted event for task {task.id}") + except Exception as e: + logger.error(f"Failed to publish deleted event for task {task.id}: {e}") + # Continue with deletion even if event publishing fails + + # Save audit event for deletion (while task still exists) + save_audit_event(session, "deleted", task, user_id) + + session.delete(task) + session.commit() + + # Return 204 No Content + return + + +@router.patch("/{task_id}/toggle", response_model=TaskRead) +async def toggle_task_completion( + user_id: UUID, + task_id: int, + 
current_user: User = Depends(get_current_user), + session: Session = Depends(get_session_dep) +): + """Toggle the completion status of a task.""" + + # Verify that the user_id in the URL matches the authenticated user + if current_user.id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + # Get the task + task = session.get(Task, task_id) + + # Verify the task exists and belongs to the user + if not task or task.user_id != user_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + # Store original completion status for event + original_completed = task.completed + + # Toggle the completion status + task.completed = not task.completed + task.updated_at = datetime.utcnow() + + session.add(task) + session.commit() + session.refresh(task) + + # Publish updated event + try: + await publish_updated_event(task) + logger.info(f"Published updated event for task {task.id}") + except Exception as e: + logger.error(f"Failed to publish updated event for task {task.id}: {e}") + + # Save audit event for update + save_audit_event(session, "updated", task, user_id) + + # If the task was marked as completed, publish a completed event + if not original_completed and task.completed: + try: + await publish_completed_event(task) + logger.info(f"Published completed event for task {task.id}") + except Exception as e: + logger.error(f"Failed to publish completed event for task {task.id}: {e}") + + # Save audit event for completion + save_audit_event(session, "completed", task, user_id) + + return TaskRead( + id=task.id, + user_id=task.user_id, + title=task.title, + description=task.description, + completed=task.completed, + created_at=task.created_at, + updated_at=task.updated_at + ) \ No newline at end of file diff --git a/src/schemas/auth.py b/src/schemas/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..c221bbbb5a031871ae70c3cfee05477f7a3696db --- 
/dev/null +++ b/src/schemas/auth.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel, EmailStr +from typing import Optional +from datetime import datetime +from uuid import UUID + + +class RegisterRequest(BaseModel): + email: EmailStr + password: str + + +class RegisterResponse(BaseModel): + id: UUID + email: EmailStr + message: str + + +class LoginRequest(BaseModel): + email: EmailStr + password: str + + +class LoginResponse(BaseModel): + access_token: str + token_type: str + user: RegisterResponse + + +class ErrorResponse(BaseModel): + detail: str + status_code: Optional[int] = None + errors: Optional[list] = None + + +class ForgotPasswordRequest(BaseModel): + email: EmailStr + + +class ResetPasswordRequest(BaseModel): + email: EmailStr + new_password: str \ No newline at end of file diff --git a/src/schemas/task.py b/src/schemas/task.py new file mode 100644 index 0000000000000000000000000000000000000000..1a395fa211b352e5cf705bb23a5444530970cd39 --- /dev/null +++ b/src/schemas/task.py @@ -0,0 +1,39 @@ +from pydantic import BaseModel +from typing import List, Optional +from datetime import datetime +from uuid import UUID + + +class TaskBase(BaseModel): + title: str + description: Optional[str] = None + completed: bool = False + due_date: Optional[datetime] = None + project_id: Optional[UUID] = None + + +class TaskCreate(TaskBase): + title: str + description: Optional[str] = None + + +class TaskUpdate(BaseModel): + title: Optional[str] = None + description: Optional[str] = None + completed: Optional[bool] = None + + +class TaskRead(TaskBase): + id: int + user_id: UUID + due_date: Optional[datetime] = None + project_id: Optional[UUID] = None + created_at: datetime + updated_at: datetime + + +class TaskListResponse(BaseModel): + tasks: List[TaskRead] + total: int + offset: int + limit: int \ No newline at end of file diff --git a/src/services/conversation_service.py b/src/services/conversation_service.py new file mode 100644 index 
0000000000000000000000000000000000000000..6ef54195943da87196951e07fa4885bcdf0f7e55 --- /dev/null +++ b/src/services/conversation_service.py @@ -0,0 +1,71 @@ +from sqlmodel import Session, select +from typing import List, Optional +from datetime import datetime +import uuid +from ..models.conversation import Conversation, ConversationCreate +from ..models.message import Message, MessageCreate + + +class ConversationService: + @staticmethod + def create_conversation(user_id: uuid.UUID, db_session: Session) -> Conversation: + """Create a new conversation for a user""" + conversation = Conversation(user_id=user_id) + db_session.add(conversation) + db_session.commit() + db_session.refresh(conversation) + return conversation + + @staticmethod + def get_conversation_by_id(conversation_id: int, user_id: uuid.UUID, db_session: Session) -> Optional[Conversation]: + """Get a conversation by ID for a specific user (enforces user isolation)""" + statement = select(Conversation).where( + Conversation.id == conversation_id, + Conversation.user_id == user_id + ) + return db_session.exec(statement).first() + + @staticmethod + def get_messages(conversation_id: int, user_id: uuid.UUID, db_session: Session, limit: int = 20) -> List[Message]: + """Get messages from a conversation with user isolation enforced""" + # First verify the conversation belongs to the user + conversation = ConversationService.get_conversation_by_id(conversation_id, user_id, db_session) + if not conversation: + return [] + + # Get messages for this conversation + statement = select(Message).where( + Message.conversation_id == conversation_id + ).order_by(Message.created_at.desc()).limit(limit) + + messages = db_session.exec(statement).all() + # Reverse to return in chronological order (oldest first) + return list(reversed(messages)) + + @staticmethod + def add_message(conversation_id: int, user_id: uuid.UUID, role: str, content: str, db_session: Session) -> Message: + """Add a message to a conversation with user 
isolation enforced""" + # Verify the conversation belongs to the user + conversation = ConversationService.get_conversation_by_id(conversation_id, user_id, db_session) + if not conversation: + raise ValueError("Conversation not found or does not belong to user") + + message = Message( + conversation_id=conversation_id, + user_id=user_id, + role=role, + content=content + ) + db_session.add(message) + db_session.commit() + db_session.refresh(message) + return message + + @staticmethod + def get_latest_conversation(user_id: uuid.UUID, db_session: Session) -> Optional[Conversation]: + """Get the most recent conversation for a user""" + statement = select(Conversation).where( + Conversation.user_id == user_id + ).order_by(Conversation.created_at.desc()).limit(1) + + return db_session.exec(statement).first() \ No newline at end of file diff --git a/src/task_api.egg-info/PKG-INFO b/src/task_api.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..1c7c47c051182e9ac90b7ba5da1ed827aaee376e --- /dev/null +++ b/src/task_api.egg-info/PKG-INFO @@ -0,0 +1,16 @@ +Metadata-Version: 2.4 +Name: task-api +Version: 0.1.0 +Summary: Add your description here +Requires-Python: >=3.12 +Description-Content-Type: text/markdown +Requires-Dist: alembic>=1.17.2 +Requires-Dist: fastapi>=0.124.4 +Requires-Dist: passlib[bcrypt]>=1.7.4 +Requires-Dist: psycopg2-binary>=2.9.11 +Requires-Dist: pydantic-settings>=2.12.0 +Requires-Dist: pydantic[email]>=2.12.5 +Requires-Dist: python-jose[cryptography]>=3.5.0 +Requires-Dist: python-multipart>=0.0.20 +Requires-Dist: sqlmodel>=0.0.27 +Requires-Dist: uvicorn>=0.38.0 diff --git a/src/task_api.egg-info/SOURCES.txt b/src/task_api.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..462007061bce7601d7b8a22c46ac70f3d8439b26 --- /dev/null +++ b/src/task_api.egg-info/SOURCES.txt @@ -0,0 +1,21 @@ +README.md +pyproject.toml +src/__init__.py +src/config.py +src/database.py +src/main.py 
+src/middleware/auth.py +src/models/task.py +src/models/user.py +src/routers/__init__.py +src/routers/auth.py +src/routers/tasks.py +src/schemas/auth.py +src/schemas/task.py +src/task_api.egg-info/PKG-INFO +src/task_api.egg-info/SOURCES.txt +src/task_api.egg-info/dependency_links.txt +src/task_api.egg-info/requires.txt +src/task_api.egg-info/top_level.txt +src/utils/deps.py +src/utils/security.py \ No newline at end of file diff --git a/src/task_api.egg-info/dependency_links.txt b/src/task_api.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/task_api.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/task_api.egg-info/requires.txt b/src/task_api.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..e602cbe22e4dd2c7109353f698ee00d598a5a425 --- /dev/null +++ b/src/task_api.egg-info/requires.txt @@ -0,0 +1,10 @@ +alembic>=1.17.2 +fastapi>=0.124.4 +passlib[bcrypt]>=1.7.4 +psycopg2-binary>=2.9.11 +pydantic-settings>=2.12.0 +pydantic[email]>=2.12.5 +python-jose[cryptography]>=3.5.0 +python-multipart>=0.0.20 +sqlmodel>=0.0.27 +uvicorn>=0.38.0 diff --git a/src/task_api.egg-info/top_level.txt b/src/task_api.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..4c63ae26e92270f778cff08099c97ead89b451d2 --- /dev/null +++ b/src/task_api.egg-info/top_level.txt @@ -0,0 +1,9 @@ +__init__ +config +database +main +middleware +models +routers +schemas +utils diff --git a/src/utils/circuit_breaker.py b/src/utils/circuit_breaker.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d1cb0bb65fc9a075176074d0df1d2df4717d08 --- /dev/null +++ b/src/utils/circuit_breaker.py @@ -0,0 +1,77 @@ +import asyncio +import time +from enum import Enum +from typing import Callable, Any, Awaitable +import logging + +logger = logging.getLogger(__name__) + + +class CircuitState(Enum): + 
CLOSED = "closed" # Normal operation + OPEN = "open" # Trip when failures exceed threshold + HALF_OPEN = "half_open" # Test if service recovered + + +class CircuitBreaker: + def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 60.0, expected_exception: tuple = (Exception,)): + self.failure_threshold = failure_threshold + self.recovery_timeout = recovery_timeout + self.expected_exception = expected_exception + + self.state = CircuitState.CLOSED + self.failure_count = 0 + self.last_failure_time = None + self._lock = asyncio.Lock() + + async def call(self, func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any: + async with self._lock: + if self.state == CircuitState.OPEN: + if time.time() - self.last_failure_time >= self.recovery_timeout: + self.state = CircuitState.HALF_OPEN + logger.info("Circuit breaker transitioning to HALF_OPEN") + else: + raise Exception("Circuit breaker is OPEN") + + if self.state == CircuitState.HALF_OPEN: + try: + result = await func(*args, **kwargs) + async with self._lock: + self.state = CircuitState.CLOSED + self.failure_count = 0 + logger.info("Circuit breaker closed after successful call") + return result + except self.expected_exception: + async with self._lock: + self.state = CircuitState.OPEN + self.failure_count = self.failure_threshold # Force open state + self.last_failure_time = time.time() + logger.warning("Circuit breaker opened after failed attempt in HALF_OPEN state") + raise + elif self.state == CircuitState.CLOSED: + try: + result = await func(*args, **kwargs) + async with self._lock: + # Reset failure count on success + self.failure_count = 0 + return result + except self.expected_exception as e: + async with self._lock: + self.failure_count += 1 + if self.failure_count >= self.failure_threshold: + self.state = CircuitState.OPEN + self.last_failure_time = time.time() + logger.warning(f"Circuit breaker opened after {self.failure_count} consecutive failures") + else: + logger.warning(f"Circuit 
def get_current_user(
    request: Request,
    session: Session = Depends(get_session_dep)
) -> User:
    """Resolve the authenticated user from a JWT carried in the
    ``access_token`` cookie or, failing that, a ``Bearer`` Authorization header.

    Raises:
        HTTPException: 401 when no token is present, the token does not yield
            a valid user id, or the referenced user no longer exists.
    """
    # SECURITY: the previous version print()ed raw cookies, headers and the
    # JWT itself on every request — removed, as that leaks credentials into
    # process output/log aggregation.
    token = request.cookies.get("access_token")

    # Fall back to the Authorization header when no cookie is set.
    if not token:
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.startswith("Bearer "):
            token = auth_header[7:]  # strip the "Bearer " prefix

    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user_id = verify_user_id_from_token(token)
    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user = session.get(User, user_id)
    if not user:
        # Token was valid but the account is gone; treat as unauthenticated.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return user


def get_user_by_id(
    user_id: UUID,
    session: Session = Depends(get_session_dep)
) -> User:
    """Dependency to get a user by ID from the database.

    Raises:
        HTTPException: 404 when no user with ``user_id`` exists.
    """
    user = session.get(User, user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )
    return user
class KafkaHealthChecker:
    """Checks Kafka reachability indirectly via the local Dapr sidecar."""

    def __init__(self, dapr_http_port: int = 3500):
        self.dapr_http_port = dapr_http_port
        # Cached result of the most recent successful check (observability only).
        self.last_check_time = None
        self.last_status = None

    async def check_kafka_connectivity(self) -> Dict[str, Any]:
        """
        Check if Dapr can reach Kafka by attempting to publish a test message.

        Returns:
            Dictionary with keys ``status`` ("healthy"/"unhealthy"),
            ``details``, ``timestamp`` (ISO-8601) and ``response_time_ms``.
        """
        start_time = datetime.now()

        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                # 1) The sidecar itself must be up before Kafka can be judged.
                dapr_health_url = f"http://localhost:{self.dapr_http_port}/v1.0/healthz"
                dapr_response = await client.get(dapr_health_url)

                if dapr_response.status_code != 200:
                    logger.error(f"Dapr sidecar health check failed: {dapr_response.status_code}")
                    return {
                        "status": "unhealthy",
                        "details": f"Dapr sidecar not healthy: {dapr_response.status_code}",
                        "timestamp": start_time.isoformat(),
                        "response_time_ms": (datetime.now() - start_time).total_seconds() * 1000
                    }

                # 2) Best-effort component listing; a failure here does not by
                # itself indicate a Kafka problem, so only warn and continue.
                components_url = f"http://localhost:{self.dapr_http_port}/v1.0/components"
                components_response = await client.get(components_url)
                if components_response.status_code != 200:
                    logger.warning("Could not fetch Dapr components list")

                # 3) Publish a synthetic event through the pubsub component to
                # exercise the actual Dapr -> Kafka path.
                test_event = {
                    "event_id": "health-check-" + str(int(start_time.timestamp())),
                    "event_type": "health_check",
                    "timestamp": start_time.isoformat() + "Z",
                    "user_id": "system",
                    "task_id": 0,
                    "task_data": {
                        "title": "Health Check",
                        "description": "System health verification",
                        "completed": False
                    }
                }

                response = await client.post(
                    f"http://localhost:{self.dapr_http_port}/v1.0/publish/kafka-pubsub/task-events",
                    json=test_event
                )

                # FIX: the publish response was previously ignored, so a failed
                # publish still reported "healthy". Dapr returns 2xx (204) on
                # accepted publish; anything else means the broker is unreachable.
                if response.status_code >= 400:
                    status = "unhealthy"
                    details = f"Dapr publish to Kafka failed: {response.status_code}"
                    logger.error(f"Kafka connectivity check: publish failed - {response.status_code}")
                else:
                    status = "healthy"
                    details = "Successfully connected to Kafka via Dapr"
                    self.last_check_time = start_time
                    self.last_status = status
                    logger.info("Kafka connectivity check: healthy")

        except httpx.TimeoutException:
            status = "unhealthy"
            details = "Timeout connecting to Dapr sidecar or Kafka"
            logger.error("Kafka connectivity check: timeout")

        except httpx.RequestError as e:
            status = "unhealthy"
            details = f"Connection error: {str(e)}"
            logger.error(f"Kafka connectivity check: connection error - {e}")

        except Exception as e:
            status = "unhealthy"
            details = f"Unexpected error: {str(e)}"
            logger.error(f"Kafka connectivity check: unexpected error - {e}")

        return {
            "status": status,
            "details": details,
            "timestamp": start_time.isoformat(),
            "response_time_ms": (datetime.now() - start_time).total_seconds() * 1000
        }


# Global instance
kafka_health_checker = KafkaHealthChecker()
from prometheus_client import Counter, Histogram, Gauge
import time

# ---------------------------------------------------------------------------
# Event publishing metrics
# ---------------------------------------------------------------------------
event_published_counter = Counter(
    'task_events_published_total',
    'Total number of task events published',
    ['event_type']
)

event_publish_duration = Histogram(
    'task_event_publish_duration_seconds',
    'Time spent publishing task events',
    ['event_type']
)

event_publish_errors = Counter(
    'task_event_publish_errors_total',
    'Total number of task event publishing errors',
    ['event_type']
)

# ---------------------------------------------------------------------------
# Circuit breaker metrics
# ---------------------------------------------------------------------------
circuit_breaker_state = Gauge(
    'circuit_breaker_state',
    'Current state of the circuit breaker (0=closed, 1=open, 2=half-open)',
    ['breaker_name']
)

circuit_breaker_failures = Counter(
    'circuit_breaker_failures_total',
    'Total number of circuit breaker failures',
    ['breaker_name']
)

# ---------------------------------------------------------------------------
# Rate limiting metrics (for T048)
# ---------------------------------------------------------------------------
rate_limiter_requests = Counter(
    'rate_limiter_requests_total',
    'Total number of requests to rate limiter',
    ['endpoint']
)

rate_limiter_rejections = Counter(
    'rate_limiter_rejections_total',
    'Total number of requests rejected by rate limiter',
    ['endpoint']
)


# Thin helpers below exist so callers never touch Prometheus label plumbing
# directly; each maps one-to-one onto a metric defined above.

def increment_event_published(event_type: str):
    """Count one published task event of the given type."""
    event_published_counter.labels(event_type=event_type).inc()


def observe_event_publish_duration(event_type: str, duration: float):
    """Record how long (seconds) one publish of the given event type took."""
    event_publish_duration.labels(event_type=event_type).observe(duration)


def increment_event_publish_error(event_type: str):
    """Count one failed publish attempt for the given event type."""
    event_publish_errors.labels(event_type=event_type).inc()


def set_circuit_breaker_state(breaker_name: str, state_value: int):
    """Record the named breaker's state (0=closed, 1=open, 2=half-open)."""
    circuit_breaker_state.labels(breaker_name=breaker_name).set(state_value)


def increment_circuit_breaker_failure(breaker_name: str):
    """Count one failure observed by the named breaker."""
    circuit_breaker_failures.labels(breaker_name=breaker_name).inc()


def increment_rate_limiter_request(endpoint: str):
    """Count one request seen by the rate limiter for the endpoint."""
    rate_limiter_requests.labels(endpoint=endpoint).inc()


def increment_rate_limiter_rejection(endpoint: str):
    """Count one request rejected by the rate limiter for the endpoint."""
    rate_limiter_rejections.labels(endpoint=endpoint).inc()
class RateLimiter:
    """Thread-safe sliding-window rate limiter keyed by a client identifier.

    Each key keeps a deque of request timestamps; timestamps older than
    ``window_size`` seconds are pruned on access, and empty per-key deques
    are discarded so one-off keys cannot grow memory without bound.
    """

    def __init__(self, max_requests: int = 100, window_size: int = 60):
        """
        Initialize rate limiter.

        Args:
            max_requests: Maximum number of requests allowed per window
            window_size: Time window in seconds
        """
        self.max_requests = max_requests
        self.window_size = window_size
        # defaultdict(deque) — the class itself is the factory; no lambda needed.
        self.requests = defaultdict(deque)
        self._lock = threading.RLock()

    def is_allowed(self, key: str) -> bool:
        """
        Check if a request is allowed for the given key.

        Args:
            key: Identifier for the client/user

        Returns:
            True if request is allowed (and the request is recorded),
            False otherwise
        """
        with self._lock:
            current_time = time.time()
            window = self.requests[key]

            # Drop timestamps that have aged out of the window.
            while window and current_time - window[0] > self.window_size:
                window.popleft()

            if len(window) < self.max_requests:
                window.append(current_time)
                return True
            return False

    def get_reset_time(self, key: str) -> float:
        """
        Get the time when the rate limit will reset for the given key.

        Args:
            key: Identifier for the client/user

        Returns:
            Unix timestamp when the rate limit will reset; the current time
            when no live requests are recorded (i.e. already reset)
        """
        with self._lock:
            current_time = time.time()
            window = self.requests.get(key)

            if window:
                # FIX: prune here too — the old version could report a reset
                # time in the past while is_allowed() would already permit.
                while window and current_time - window[0] > self.window_size:
                    window.popleft()

            if window:
                return window[0] + self.window_size

            # FIX: discard empty entries so idle keys don't leak memory.
            self.requests.pop(key, None)
            return current_time


# Global rate limiter instance
event_publisher_rate_limiter = RateLimiter(max_requests=1000, window_size=60)  # 1000 events per minute
def _prepare_password(password: str) -> str:
    """
    Prepare password for bcrypt by hashing if longer than 72 bytes.

    bcrypt silently truncates input at 72 bytes; pre-hashing long passwords
    with SHA-256 keeps all their entropy within bcrypt's limit.
    """
    password_bytes = password.encode('utf-8')
    if len(password_bytes) > 72:
        # 64-char hex digest is well under the 72-byte bcrypt limit.
        return hashlib.sha256(password_bytes).hexdigest()
    return password


def hash_password(password: str) -> str:
    """Hash a password using bcrypt, handling long passwords via SHA256 pre-hashing."""
    return pwd_context.hash(_prepare_password(password))


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a plain password against its hash, handling long passwords via SHA256 pre-hashing."""
    return pwd_context.verify(_prepare_password(plain_password), hashed_password)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. {"sub": "<user-id>"}); copied, not mutated.
        expires_delta: Optional lifetime; defaults to
            ``settings.ACCESS_TOKEN_EXPIRE_DAYS`` days.

    Returns:
        The encoded JWT with ``exp`` and ``iat`` claims added.
    """
    # Local import: the module's datetime imports live at the top of the file.
    from datetime import timezone

    # FIX: datetime.utcnow() is deprecated and returns a naive datetime;
    # timezone-aware UTC makes the exp/iat claims unambiguous.
    now = datetime.now(timezone.utc)
    if expires_delta is not None:
        expire = now + expires_delta
    else:
        # Default lifetime if no expiration is provided.
        expire = now + timedelta(days=settings.ACCESS_TOKEN_EXPIRE_DAYS)

    to_encode = data.copy()
    to_encode.update({"exp": expire, "iat": now})

    return jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)


def verify_token(token: str) -> Optional[dict]:
    """Verify a JWT token and return the payload if valid, else None."""
    try:
        return jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except JWTError:
        # Covers bad signature, malformed token, and expired exp claim.
        return None


def verify_user_id_from_token(token: str) -> Optional[uuid.UUID]:
    """Extract the user id (``sub`` claim) from a JWT; None if absent/invalid."""
    payload = verify_token(token)
    if payload:
        user_id_str = payload.get("sub")
        if user_id_str:
            try:
                return uuid.UUID(user_id_str)
            except ValueError:
                # sub claim present but not a valid UUID string.
                return None
    return None