MuzammilMax commited on
Commit
ed5cf78
·
verified ·
1 Parent(s): b908ed7

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use the official Python image
2
+ FROM python:3.10-slim
3
+
4
+ # Set the working directory
5
+ WORKDIR /app
6
+
7
+ # Copy requirements and install dependencies
8
+ COPY requirements.txt .
9
+ RUN pip install --no-cache-dir -r requirements.txt
10
+
11
+ # Copy the rest of the application code
12
+ COPY . .
13
+
14
+ # Create a non-root user for security (Hugging Face Spaces often run as user 1000)
15
+ RUN useradd -m -u 1000 user
16
+ USER user
17
+ ENV PATH="/home/user/.local/bin:$PATH"
18
+
19
+ # Expose the port Hugging Face Spaces expects (7860)
20
+ EXPOSE 7860
21
+
22
+ # Command to run the application
23
+ # We use uvicorn to run the app
24
+ # host 0.0.0.0 is required for container networking
25
+ # port 7860 is required by Hugging Face Spaces
26
+ CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
check_function_tool_type.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import agents
2
+ print(f"agents.function_tool is: {type(agents.function_tool)}")
3
+ print(f"agents.FunctionTool is: {type(agents.FunctionTool)}")
init_db.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import sys
3
+ from src.database import init_db
4
+ from src.models import User, Task, Session, Account, Verification, Conversation, Message # Ensure all models are registered
5
+ from src.database import engine
6
+ from sqlmodel import SQLModel
7
+
8
+ async def recreate_db():
9
+ print("Dropping existing tables...")
10
+ async with engine.begin() as conn:
11
+ await conn.run_sync(SQLModel.metadata.drop_all)
12
+ print("Initializing database...")
13
+ async with engine.begin() as conn:
14
+ await conn.run_sync(SQLModel.metadata.create_all)
15
+ print("Database initialized successfully with Better Auth tables.")
16
+
17
+ if __name__ == "__main__":
18
+ if sys.platform == 'win32':
19
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
20
+ asyncio.run(recreate_db())
pyproject.toml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "mytodo-backend"
3
+ version = "0.1.0"
4
+ description = "MyTodoApp Backend"
5
+ requires-python = ">=3.10"
6
+
7
+ [tool.ruff]
8
+ line-length = 88
9
+ target-version = "py310"
10
+
11
+ [tool.black]
12
+ line-length = 88
13
+ target-version = ['py310']
14
+
15
+ [tool.pytest.ini_options]
16
+ asyncio_mode = "auto"
17
+ testpaths = ["tests"]
18
+ python_files = ["test_*.py"]
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.115.0
2
+ uvicorn[standard]>=0.30.0
3
+ sqlmodel>=0.0.22
4
+ psycopg[binary]>=3.2.0
5
+ asyncpg>=0.29.0
6
+ pydantic-settings>=2.5.0
7
+ python-jose[cryptography]>=3.3.0
8
+ passlib[bcrypt]>=1.7.4
9
+ python-multipart>=0.0.12
10
+ pytest>=8.0.0
11
+ pytest-asyncio>=0.24.0
12
+ httpx>=0.27.0
13
+ openai-agents
14
+ mcp
15
+ openai
16
+ sqlalchemy>=2.0.0
17
+ greenlet
18
+ python-dotenv
src/agents/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Agents module
src/agents/chatbot.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ from typing import List, Optional, Dict, Any
4
+ from openai import AsyncOpenAI
5
+ from agents import Agent, Runner, set_default_openai_client, function_tool, ModelSettings, set_tracing_disabled, OpenAIChatCompletionsModel
6
+ from src.config import settings
7
+ from src.agents.tools import (
8
+ create_task_tool,
9
+ list_tasks_tool,
10
+ update_task_tool,
11
+ delete_task_tool
12
+ )
13
+
14
+ # Configure Gemini via OpenAI SDK (OpenRouter)
15
+ GEMINI_API_KEY = settings.GEMINI_API_KEY
16
+ client = None
17
+ if GEMINI_API_KEY:
18
+ # Disable tracing to avoid 401 errors from OpenAI-specific telemetry
19
+ set_tracing_disabled(True)
20
+
21
+ # Set default client for the entire process (using OpenRouter)
22
+ client = AsyncOpenAI(
23
+ api_key=GEMINI_API_KEY,
24
+ base_url="https://openrouter.ai/api/v1",
25
+ default_headers={
26
+ "HTTP-Referer": "http://localhost:3000",
27
+ "X-Title": "MyTodoApp",
28
+ }
29
+ )
30
+ set_default_openai_client(client)
31
+ # Set OPENAI_API_KEY to satisfy internal SDK requirements
32
+ os.environ["OPENAI_API_KEY"] = GEMINI_API_KEY
33
+
34
+ class TodoChatbot:
35
+ def __init__(self, db, user_id: str):
36
+ self.db = db
37
+ self.user_id = user_id
38
+ # Lock to ensure sequential database access per chatbot instance
39
+ self.lock = asyncio.Lock()
40
+
41
+ # Define local functions that wrap the tools with session, user context, and Lock
42
+ async def create_task(title: str, description: Optional[str] = None, status: str = "TODO", priority: str = "MEDIUM") -> str:
43
+ """
44
+ Create a new task for the user.
45
+ Args:
46
+ title: The title of the task.
47
+ description: A detailed description of the task.
48
+ status: The status of the task (TODO, IN_PROGRESS, DONE).
49
+ priority: The priority of the task (LOW, MEDIUM, HIGH).
50
+ """
51
+ async with self.lock:
52
+ return await create_task_tool(self.db, self.user_id, title, description, status, priority)
53
+
54
+ async def list_tasks(status: Optional[str] = None) -> str:
55
+ """
56
+ List tasks for the user.
57
+ Args:
58
+ status: Optional filter by status (TODO, IN_PROGRESS, DONE).
59
+ """
60
+ async with self.lock:
61
+ return await list_tasks_tool(self.db, self.user_id, status)
62
+
63
+ async def update_task(task_id: str, title: Optional[str] = None, description: Optional[str] = None, status: Optional[str] = None, priority: Optional[str] = None) -> str:
64
+ """
65
+ Update an existing task.
66
+ Args:
67
+ task_id: The UUID of the task to update.
68
+ title: New title for the task.
69
+ description: New description for the task.
70
+ status: New status for the task.
71
+ priority: New priority for the task.
72
+ """
73
+ async with self.lock:
74
+ return await update_task_tool(self.db, self.user_id, task_id, title, description, status, priority)
75
+
76
+ async def delete_task(task_id: str) -> str:
77
+ """
78
+ Delete a task.
79
+ Args:
80
+ task_id: The UUID of the task to delete.
81
+ """
82
+ async with self.lock:
83
+ return await delete_task_tool(self.db, self.user_id, task_id)
84
+
85
+ # Wrap OpenRouter model to bypass SDK prefix validation
86
+ gemini_model = OpenAIChatCompletionsModel(
87
+ model="google/gemini-2.0-flash-001",
88
+ openai_client=client
89
+ ) if client else "gpt-4o"
90
+
91
+ self.agent = Agent(
92
+ name="TodoBot",
93
+ model=gemini_model,
94
+ instructions=(
95
+ "You are an extremely concise Todo List assistant. "
96
+ "Actions: Create, List, Update (Mark Done), Delete. "
97
+ "Instructions: "
98
+ "1. BE BRIEF. No conversational filler like 'Sure, I can help'. "
99
+ "2. When asked to update/complete/delete a task, ALWAYS call 'list_tasks' first to find the ID. "
100
+ "3. To mark a task as DONE, use 'update_task' with status='DONE'. "
101
+ "4. Confirm actions with a single short sentence: 'Done: Task marked complete.' or 'Task created.' "
102
+ "5. Never ask follow-up questions unless the request is truly ambiguous."
103
+ ),
104
+ tools=[
105
+ function_tool(create_task),
106
+ function_tool(list_tasks),
107
+ function_tool(update_task),
108
+ function_tool(delete_task)
109
+ ],
110
+ model_settings=ModelSettings(temperature=0.1, max_tokens=1024, parallel_tool_calls=False)
111
+ )
112
+
113
+ async def get_response(self, messages: List[Dict[str, Any]]) -> Any:
114
+ """
115
+ Process a list of messages and return the agent's response.
116
+ """
117
+ runner = Runner()
118
+ result = await runner.run(self.agent, messages)
119
+ return result
src/agents/tools.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import uuid
3
+ from typing import List, Optional, Any
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy.future import select
6
+ from src.models import Task, TaskStatus, TaskPriority
7
+
8
+ class ToolResult:
9
+ def __init__(self, success: bool, data: Any = None, error: str = None):
10
+ self.success = success
11
+ self.data = data
12
+ self.error = error
13
+
14
+ def to_json(self) -> str:
15
+ return json.dumps({
16
+ "success": self.success,
17
+ "data": self.data,
18
+ "error": self.error
19
+ })
20
+
21
+ # Base MCP Tool Wrapper
22
+ # In a real MCP setup, these would be registered with an MCP server.
23
+ # For this implementation, they will be used by the OpenAI Agents SDK.
24
+
25
+ def format_task(task: Task) -> dict:
26
+ return {
27
+ "id": str(task.id),
28
+ "title": task.title,
29
+ "description": task.description,
30
+ "status": task.status,
31
+ "priority": task.priority,
32
+ "created_at": task.created_at.isoformat()
33
+ }
34
+
35
+ async def create_task_tool(db: AsyncSession, user_id: str, title: str, description: Optional[str] = None, status: str = "TODO", priority: str = "MEDIUM") -> str:
36
+
37
+ """Create a new task for the authenticated user."""
38
+
39
+ print(f"Tool [create_task]: user={user_id} title='{title}' status={status}")
40
+
41
+ # Explicitly convert strings to Enums
42
+
43
+ task_status = TaskStatus(status.upper()) if isinstance(status, str) else status
44
+
45
+ task_priority = TaskPriority(priority.upper()) if isinstance(priority, str) else priority
46
+
47
+
48
+
49
+ task = Task(title=title, description=description, status=task_status, priority=task_priority, owner_id=user_id)
50
+
51
+ db.add(task)
52
+
53
+ await db.commit()
54
+
55
+ await db.refresh(task)
56
+
57
+ return json.dumps(format_task(task))
58
+
59
+
60
+
61
+ async def list_tasks_tool(db: AsyncSession, user_id: str, status: Optional[str] = None) -> str:
62
+
63
+ """List all tasks for the authenticated user, optionally filtered by status."""
64
+
65
+ print(f"Tool [list_tasks]: user={user_id} status_filter={status}")
66
+
67
+ statement = select(Task).where(Task.owner_id == user_id)
68
+
69
+ if status:
70
+
71
+ task_status = TaskStatus(status.upper()) if isinstance(status, str) else status
72
+
73
+ statement = statement.where(Task.status == task_status)
74
+
75
+ result = await db.execute(statement)
76
+
77
+ tasks = result.scalars().all()
78
+
79
+ return json.dumps([format_task(t) for t in tasks])
80
+
81
+
82
+
83
+ async def update_task_tool(db: AsyncSession, user_id: str, task_id: str, title: Optional[str] = None, description: Optional[str] = None, status: Optional[str] = None, priority: Optional[str] = None) -> str:
84
+
85
+ """Update an existing task for the authenticated user."""
86
+
87
+ print(f"Tool [update_task]: id={task_id} status_update={status}")
88
+
89
+ try:
90
+
91
+ task_uuid = uuid.UUID(task_id)
92
+
93
+ except ValueError:
94
+
95
+ return json.dumps({"error": "Invalid task ID format"})
96
+
97
+
98
+
99
+ statement = select(Task).where(Task.id == task_uuid, Task.owner_id == user_id)
100
+
101
+ result = await db.execute(statement)
102
+
103
+ task = result.scalars().first()
104
+
105
+ if not task:
106
+
107
+ print(f"Tool [update_task]: Task {task_id} NOT FOUND for user {user_id}")
108
+
109
+ return json.dumps({"error": "Task not found or access denied"})
110
+
111
+
112
+
113
+ if title is not None:
114
+
115
+ task.title = title
116
+
117
+ if description is not None:
118
+
119
+ task.description = description
120
+
121
+ if status is not None:
122
+
123
+ task.status = TaskStatus(status.upper()) if isinstance(status, str) else status
124
+
125
+ if priority is not None:
126
+
127
+ task.priority = TaskPriority(priority.upper()) if isinstance(priority, str) else priority
128
+
129
+
130
+
131
+ db.add(task)
132
+
133
+ await db.commit()
134
+
135
+ await db.refresh(task)
136
+
137
+ print(f"Tool [update_task]: Task {task_id} UPDATED successfully. New status: {task.status}")
138
+
139
+ return json.dumps(format_task(task))
140
+
141
+
142
+
143
+ async def delete_task_tool(db: AsyncSession, user_id: str, task_id: str) -> str:
144
+
145
+ """Delete a task for the authenticated user."""
146
+
147
+ print(f"Tool [delete_task]: id={task_id}")
src/auth.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from fastapi import Depends, HTTPException, status
3
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy.future import select
6
+ from src.database import get_session
7
+ from src.models import User, Session
8
+
9
+ security = HTTPBearer()
10
+
11
+ async def get_current_user(auth: HTTPAuthorizationCredentials = Depends(security), session: AsyncSession = Depends(get_session)) -> User:
12
+ token = auth.credentials
13
+
14
+ # Query Better Auth session table
15
+ stmt = select(Session).where(Session.token == token)
16
+ result = await session.execute(stmt)
17
+ db_session = result.scalars().first()
18
+
19
+ if not db_session:
20
+ raise HTTPException(
21
+ status_code=status.HTTP_401_UNAUTHORIZED,
22
+ detail="Invalid session token",
23
+ headers={"WWW-Authenticate": "Bearer"},
24
+ )
25
+
26
+ if db_session.expiresAt < datetime.utcnow():
27
+ raise HTTPException(
28
+ status_code=status.HTTP_401_UNAUTHORIZED,
29
+ detail="Session expired",
30
+ headers={"WWW-Authenticate": "Bearer"},
31
+ )
32
+
33
+ user = await session.get(User, db_session.userId)
34
+ if user is None:
35
+ raise HTTPException(
36
+ status_code=status.HTTP_401_UNAUTHORIZED,
37
+ detail="User not found",
38
+ headers={"WWW-Authenticate": "Bearer"},
39
+ )
40
+ return user
src/config.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic_settings import BaseSettings
2
+ from typing import Optional
3
+
4
+ class Settings(BaseSettings):
5
+ DATABASE_URL: str = "postgresql+asyncpg://user:password@localhost/mytodo_db"
6
+ SECRET_KEY: str = "your-secret-key-must-be-changed-in-production"
7
+ ALGORITHM: str = "HS256"
8
+ ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
9
+ GEMINI_API_KEY: Optional[str] = None
10
+
11
+ class Config:
12
+ env_file = ".env"
13
+ extra = "ignore"
14
+
15
+ settings = Settings()
src/database.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlmodel import SQLModel, create_engine
2
+ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
3
+ from sqlalchemy.orm import sessionmaker
4
+ from src.config import settings
5
+
6
+ DATABASE_URL = settings.DATABASE_URL
7
+
8
+ # Pass SSL mode explicitly for asyncpg
9
+ engine = create_async_engine(
10
+ DATABASE_URL,
11
+ echo=True,
12
+ future=True,
13
+ connect_args={"ssl": "require"}
14
+ )
15
+
16
+ async def get_session() -> AsyncSession:
17
+ async_session = sessionmaker(
18
+ engine, class_=AsyncSession, expire_on_commit=False
19
+ )
20
+ async with async_session() as session:
21
+ yield session
22
+
23
+ async def init_db():
24
+ async with engine.begin() as conn:
25
+ # await conn.run_sync(SQLModel.metadata.drop_all)
26
+ await conn.run_sync(SQLModel.metadata.create_all)
src/main.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ from fastapi import FastAPI, Request
4
+ from fastapi.responses import JSONResponse
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from src.config import settings
7
+ from src.routes import tasks, chat
8
+
9
+ # Set event loop policy for Windows to support asyncpg/SQLAlchemy correctly
10
+ if os.name == 'nt':
11
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
12
+
13
+ app = FastAPI(title="MyTodoApp Backend", version="0.1.0")
14
+
15
+ # CORS
16
+ # In production, set BACKEND_CORS_ORIGINS="https://your-frontend.vercel.app"
17
+ origins = settings.BACKEND_CORS_ORIGINS.split(",") if hasattr(settings, "BACKEND_CORS_ORIGINS") and settings.BACKEND_CORS_ORIGINS else [
18
+ "http://localhost:3000",
19
+ "http://127.0.0.1:3000",
20
+ "*" # Allow all for hackathon/testing ease - REMOVE in strict production
21
+ ]
22
+
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=origins,
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ app.include_router(tasks.router)
32
+ app.include_router(chat.router)
33
+
34
+ @app.exception_handler(Exception)
35
+ async def global_exception_handler(request: Request, exc: Exception):
36
+ # Print exception for debugging
37
+ print(f"Global exception: {exc}")
38
+ return JSONResponse(
39
+ status_code=500,
40
+ content={"detail": str(exc)},
41
+ )
42
+
43
+ @app.get("/health")
44
+ async def health_check():
45
+ return {"status": "ok"}
src/models.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from datetime import datetime
3
+ from typing import Optional, List
4
+ from enum import Enum
5
+ from sqlmodel import Field, SQLModel, Relationship
6
+
7
+ class TaskStatus(str, Enum):
8
+ TODO = "TODO"
9
+ IN_PROGRESS = "IN_PROGRESS"
10
+ DONE = "DONE"
11
+
12
+ class TaskPriority(str, Enum):
13
+ LOW = "LOW"
14
+ MEDIUM = "MEDIUM"
15
+ HIGH = "HIGH"
16
+
17
+ # Better Auth Compatible Models
18
+ class UserBase(SQLModel):
19
+ email: str = Field(unique=True, index=True)
20
+ name: Optional[str] = None
21
+ image: Optional[str] = None
22
+
23
+ class User(UserBase, table=True):
24
+ id: str = Field(primary_key=True)
25
+ emailVerified: bool = Field(default=False)
26
+ createdAt: datetime = Field(default_factory=datetime.utcnow)
27
+ updatedAt: datetime = Field(default_factory=datetime.utcnow)
28
+
29
+ tasks: List["Task"] = Relationship(back_populates="owner")
30
+ sessions: List["Session"] = Relationship(back_populates="user")
31
+ conversations: List["Conversation"] = Relationship(back_populates="user")
32
+
33
+ class UserCreate(UserBase):
34
+ id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()))
35
+ password: str
36
+
37
+ class UserRead(UserBase):
38
+ id: str
39
+ createdAt: datetime
40
+ updatedAt: datetime
41
+
42
+ class Session(SQLModel, table=True):
43
+ id: str = Field(primary_key=True)
44
+ expiresAt: datetime
45
+ token: str
46
+ createdAt: datetime = Field(default_factory=datetime.utcnow)
47
+ updatedAt: datetime = Field(default_factory=datetime.utcnow)
48
+ ipAddress: Optional[str] = None
49
+ userAgent: Optional[str] = None
50
+ userId: str = Field(foreign_key="user.id")
51
+
52
+ user: User = Relationship(back_populates="sessions")
53
+
54
+ class Account(SQLModel, table=True):
55
+ id: str = Field(primary_key=True)
56
+ accountId: str
57
+ providerId: str
58
+ userId: str = Field(foreign_key="user.id")
59
+ accessToken: Optional[str] = None
60
+ refreshToken: Optional[str] = None
61
+ idToken: Optional[str] = None
62
+ accessTokenExpiresAt: Optional[datetime] = None
63
+ refreshTokenExpiresAt: Optional[datetime] = None
64
+ scope: Optional[str] = None
65
+ password: Optional[str] = None
66
+ createdAt: datetime = Field(default_factory=datetime.utcnow)
67
+ updatedAt: datetime = Field(default_factory=datetime.utcnow)
68
+
69
+ class Verification(SQLModel, table=True):
70
+ id: str = Field(primary_key=True)
71
+ identifier: str
72
+ value: str
73
+ expiresAt: datetime
74
+ createdAt: datetime = Field(default_factory=datetime.utcnow)
75
+ updatedAt: datetime = Field(default_factory=datetime.utcnow)
76
+
77
+ class TaskBase(SQLModel):
78
+ title: str = Field(index=True)
79
+ description: Optional[str] = None
80
+ status: TaskStatus = Field(default=TaskStatus.TODO)
81
+ priority: Optional[TaskPriority] = Field(default=TaskPriority.MEDIUM)
82
+
83
+ class Task(TaskBase, table=True):
84
+ id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
85
+ created_at: datetime = Field(default_factory=datetime.utcnow)
86
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
87
+ owner_id: str = Field(foreign_key="user.id")
88
+
89
+ owner: User = Relationship(back_populates="tasks")
90
+
91
+ class TaskCreate(TaskBase):
92
+ pass
93
+
94
+ class TaskUpdate(SQLModel):
95
+ title: Optional[str] = None
96
+ description: Optional[str] = None
97
+ status: Optional[TaskStatus] = None
98
+ priority: Optional[TaskPriority] = None
99
+
100
+ class TaskRead(TaskBase):
101
+ id: uuid.UUID
102
+ created_at: datetime
103
+ updated_at: datetime
104
+ owner_id: str
105
+
106
+ class MessageRole(str, Enum):
107
+ USER = "user"
108
+ ASSISTANT = "assistant"
109
+ SYSTEM = "system"
110
+ TOOL = "tool"
111
+
112
+ class Conversation(SQLModel, table=True):
113
+ id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
114
+ user_id: str = Field(foreign_key="user.id")
115
+ title: Optional[str] = None
116
+ created_at: datetime = Field(default_factory=datetime.utcnow)
117
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
118
+
119
+ user: User = Relationship(back_populates="conversations")
120
+ messages: List["Message"] = Relationship(back_populates="conversation")
121
+
122
+ class Message(SQLModel, table=True):
123
+ id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
124
+ conversation_id: uuid.UUID = Field(foreign_key="conversation.id")
125
+ role: MessageRole
126
+ content: Optional[str] = None
127
+ tool_calls: Optional[str] = None # Store as JSON string
128
+ tool_call_id: Optional[str] = None # For 'tool' role messages
129
+ name: Optional[str] = None # Name of the tool called
130
+ created_at: datetime = Field(default_factory=datetime.utcnow)
131
+
132
+ conversation: Conversation = Relationship(back_populates="messages")
src/routes/chat.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException
2
+ from sqlalchemy.ext.asyncio import AsyncSession
3
+ from sqlalchemy.future import select
4
+ from src.database import get_session
5
+ from src.auth import get_current_user
6
+ from src.models import User, Conversation, Message, MessageRole
7
+ from src.agents.chatbot import TodoChatbot
8
+ from agents import ToolCallItem
9
+ from pydantic import BaseModel
10
+ import uuid
11
+ from typing import Optional
12
+ from datetime import datetime
13
+
14
+ router = APIRouter(prefix="/chat", tags=["chat"])
15
+
16
+ class ChatRequest(BaseModel):
17
+ message: str
18
+ conversation_id: Optional[str] = None
19
+
20
+ @router.post("/")
21
+ async def chat(
22
+ request: ChatRequest,
23
+ db: AsyncSession = Depends(get_session),
24
+ current_user: User = Depends(get_current_user)
25
+ ):
26
+ # 1. Resolve Conversation
27
+ if request.conversation_id:
28
+ try:
29
+ convo_uuid = uuid.UUID(request.conversation_id)
30
+ except ValueError:
31
+ raise HTTPException(status_code=400, detail="Invalid conversation_id format")
32
+
33
+ statement = select(Conversation).where(Conversation.id == convo_uuid, Conversation.user_id == current_user.id)
34
+ result = await db.execute(statement)
35
+ conversation = result.scalars().first()
36
+ if not conversation:
37
+ raise HTTPException(status_code=404, detail="Conversation not found or access denied")
38
+ else:
39
+ # Create new conversation
40
+ conversation = Conversation(user_id=current_user.id)
41
+ db.add(conversation)
42
+ await db.commit()
43
+ await db.refresh(conversation)
44
+
45
+ # 2. Store User Message
46
+ user_msg = Message(
47
+ conversation_id=conversation.id,
48
+ role=MessageRole.USER,
49
+ content=request.message
50
+ )
51
+ db.add(user_msg)
52
+ await db.commit()
53
+
54
+ # 3. Load History (Last 15 messages for more context)
55
+ history_statement = select(Message).where(Message.conversation_id == conversation.id).order_by(Message.created_at.desc()).limit(16)
56
+ result = await db.execute(history_statement)
57
+ messages_objs = result.scalars().all()
58
+ messages_objs = list(messages_objs)
59
+ messages_objs.reverse()
60
+
61
+ agent_messages = []
62
+ for msg in messages_objs:
63
+ role_val = msg.role.value
64
+ m = {
65
+ "role": role_val,
66
+ "content": msg.content or ""
67
+ }
68
+
69
+ # SDK Requirement: 'name' is only for 'tool' (and sometimes 'system')
70
+ # 'assistant' messages with a 'name' field cause Unhandled item errors.
71
+ if role_val == "tool":
72
+ m["tool_call_id"] = msg.tool_call_id
73
+ m["name"] = msg.name
74
+ elif role_val == "assistant":
75
+ if msg.tool_calls:
76
+ import json
77
+ m["tool_calls"] = json.loads(msg.tool_calls)
78
+ # DO NOT add 'name' to assistant role
79
+ elif role_val == "system":
80
+ if msg.name:
81
+ m["name"] = msg.name
82
+
83
+ agent_messages.append(m)
84
+
85
+ # 4. Run Agent
86
+ bot = TodoChatbot(db, current_user.id)
87
+ try:
88
+ # Result is a RunResult object
89
+ result = await bot.get_response(agent_messages)
90
+
91
+ # 5. Identify and Store NEW messages
92
+ full_history = result.to_input_list()
93
+ new_messages = full_history[len(agent_messages):]
94
+
95
+ import json
96
+ for msg in new_messages:
97
+ try:
98
+ role_str = msg.get("role", "assistant").lower()
99
+ # Validate role
100
+ if role_str not in [r.value for r in MessageRole]:
101
+ role_str = "assistant"
102
+
103
+ content = msg.get("content")
104
+ if isinstance(content, list):
105
+ # Flatten text content
106
+ text_parts = [item.get("text", "") for item in content if isinstance(item, dict) and item.get("type") == "output_text"]
107
+ content = "".join(text_parts).strip()
108
+ elif content is None:
109
+ content = ""
110
+
111
+ new_msg = Message(
112
+ conversation_id=conversation.id,
113
+ role=MessageRole(role_str),
114
+ content=content,
115
+ tool_calls=json.dumps(msg.get("tool_calls")) if msg.get("tool_calls") else None,
116
+ tool_call_id=msg.get("tool_call_id"),
117
+ name=msg.get("name")
118
+ )
119
+ db.add(new_msg)
120
+ except Exception as msg_err:
121
+ print(f"Error processing message for DB: {msg_err}")
122
+ print(f"Message dict: {msg}")
123
+
124
+ await db.commit()
125
+
126
+ # Update conversation timestamp
127
+ conversation.updated_at = datetime.utcnow()
128
+ db.add(conversation)
129
+ await db.commit()
130
+
131
+ # Extract the final response text
132
+ response_text = result.final_output
133
+ if not response_text and len(new_messages) > 0:
134
+ last_msg = new_messages[-1]
135
+ response_text = last_msg.get("content")
136
+ if isinstance(response_text, list):
137
+ text_parts = [item.get("text", "") for item in response_text if isinstance(item, dict) and item.get("type") == "output_text"]
138
+ response_text = "".join(text_parts).strip()
139
+
140
+ # Extract tool calls for frontend indicator
141
+ frontend_tool_calls = []
142
+ if hasattr(result, "new_items"):
143
+ for item in result.new_items:
144
+ if isinstance(item, ToolCallItem):
145
+ tool_name = "unknown"
146
+ if hasattr(item.raw_item, "name"):
147
+ tool_name = item.raw_item.name
148
+ elif hasattr(item.raw_item, "function") and hasattr(item.raw_item.function, "name"):
149
+ tool_name = item.raw_item.function.name
150
+
151
+ frontend_tool_calls.append({"function": {"name": tool_name}})
152
+
153
+ return {
154
+ "conversation_id": str(conversation.id),
155
+ "response": response_text or "I processed your request.",
156
+ "tool_calls": frontend_tool_calls
157
+ }
158
+
159
+ except Exception as e:
160
+ await db.rollback() # Crucial: Reset session state on error
161
+ import traceback
162
+ print(f"Chat Route Error: {e}")
163
+ traceback.print_exc()
164
+ raise HTTPException(status_code=500, detail=f"AI Agent failed: {str(e)}")
src/routes/tasks.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import uuid
3
+ from datetime import datetime
4
+ from fastapi import APIRouter, Depends, HTTPException, status
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+ from sqlalchemy.future import select
7
+ from src.database import get_session
8
+ from src.models import Task, TaskCreate, TaskRead, TaskUpdate, User, TaskStatus
9
+ from src.auth import get_current_user
10
+
11
+ router = APIRouter(prefix="/tasks", tags=["tasks"])
12
+
13
+ @router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED)
14
+ async def create_task(
15
+ task: TaskCreate,
16
+ current_user: User = Depends(get_current_user),
17
+ session: AsyncSession = Depends(get_session)
18
+ ):
19
+ print(f"Creating task for user {current_user.id}: {task}")
20
+ try:
21
+ new_task = Task(**task.model_dump(), owner_id=current_user.id)
22
+ session.add(new_task)
23
+ await session.commit()
24
+ await session.refresh(new_task)
25
+ print(f"Task created successfully: {new_task.id}")
26
+ return new_task
27
+ except Exception as e:
28
+ print(f"Error creating task: {e}")
29
+ await session.rollback()
30
+ raise HTTPException(status_code=500, detail=str(e))
31
+
32
+ @router.get("/", response_model=List[TaskRead])
33
+ async def read_tasks(
34
+ current_user: User = Depends(get_current_user),
35
+ session: AsyncSession = Depends(get_session)
36
+ ):
37
+ stmt = select(Task).where(Task.owner_id == current_user.id).order_by(Task.created_at.desc())
38
+ result = await session.execute(stmt)
39
+ tasks = result.scalars().all()
40
+ return tasks
41
+
42
+ @router.get("/{task_id}", response_model=TaskRead)
43
+ async def read_task(
44
+ task_id: uuid.UUID,
45
+ current_user: User = Depends(get_current_user),
46
+ session: AsyncSession = Depends(get_session)
47
+ ):
48
+ stmt = select(Task).where(Task.id == task_id, Task.owner_id == current_user.id)
49
+ result = await session.execute(stmt)
50
+ task = result.scalars().first()
51
+ if not task:
52
+ raise HTTPException(status_code=404, detail="Task not found")
53
+ return task
54
+
55
+ @router.put("/{task_id}", response_model=TaskRead)
56
+ async def update_task(
57
+ task_id: uuid.UUID,
58
+ task_update: TaskUpdate,
59
+ current_user: User = Depends(get_current_user),
60
+ session: AsyncSession = Depends(get_session)
61
+ ):
62
+ stmt = select(Task).where(Task.id == task_id, Task.owner_id == current_user.id)
63
+ result = await session.execute(stmt)
64
+ task = result.scalars().first()
65
+ if not task:
66
+ raise HTTPException(status_code=404, detail="Task not found")
67
+
68
+ task_data = task_update.model_dump(exclude_unset=True)
69
+ for key, value in task_data.items():
70
+ setattr(task, key, value)
71
+
72
+ task.updated_at = datetime.utcnow()
73
+ session.add(task)
74
+ await session.commit()
75
+ await session.refresh(task)
76
+ return task
77
+
78
+ @router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
79
+ async def delete_task(
80
+ task_id: uuid.UUID,
81
+ current_user: User = Depends(get_current_user),
82
+ session: AsyncSession = Depends(get_session)
83
+ ):
84
+ stmt = select(Task).where(Task.id == task_id, Task.owner_id == current_user.id)
85
+ result = await session.execute(stmt)
86
+ task = result.scalars().first()
87
+ if not task:
88
+ raise HTTPException(status_code=404, detail="Task not found")
89
+
90
+ await session.delete(task)
91
+ await session.commit()
92
+ return None
93
+
94
+ @router.patch("/{task_id}/complete", response_model=TaskRead)
95
+ async def toggle_task_completion(
96
+ task_id: uuid.UUID,
97
+ current_user: User = Depends(get_current_user),
98
+ session: AsyncSession = Depends(get_session)
99
+ ):
100
+ stmt = select(Task).where(Task.id == task_id, Task.owner_id == current_user.id)
101
+ result = await session.execute(stmt)
102
+ task = result.scalars().first()
103
+ if not task:
104
+ raise HTTPException(status_code=404, detail="Task not found")
105
+
106
+ if task.status == TaskStatus.DONE:
107
+ task.status = TaskStatus.TODO
108
+ else:
109
+ task.status = TaskStatus.DONE
110
+
111
+ task.updated_at = datetime.utcnow()
112
+ session.add(task)
113
+ await session.commit()
114
+ await session.refresh(task)
115
+ return task
tests/integration/test_chat.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from httpx import AsyncClient
3
+ from src.main import app
4
+ from src.models import User
5
+
6
+ @pytest.mark.asyncio
7
+ async def test_chat_unauthorized():
8
+ async with AsyncClient(app=app, base_url="http://test") as ac:
9
+ response = await ac.post("/chat/", json={"message": "hello"})
10
+ assert response.status_code == 401 # Should fail without JWT
11
+
12
+ # More tests would involve mocking get_current_user and the AI runner