Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +26 -0
- check_function_tool_type.py +3 -0
- init_db.py +20 -0
- pyproject.toml +18 -0
- requirements.txt +18 -0
- src/agents/__init__.py +1 -0
- src/agents/chatbot.py +119 -0
- src/agents/tools.py +147 -0
- src/auth.py +40 -0
- src/config.py +15 -0
- src/database.py +26 -0
- src/main.py +45 -0
- src/models.py +132 -0
- src/routes/chat.py +164 -0
- src/routes/tasks.py +115 -0
- tests/integration/test_chat.py +12 -0
Dockerfile
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use the official Python image
|
| 2 |
+
FROM python:3.10-slim
|
| 3 |
+
|
| 4 |
+
# Set the working directory
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Copy requirements and install dependencies
|
| 8 |
+
COPY requirements.txt .
|
| 9 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 10 |
+
|
| 11 |
+
# Copy the rest of the application code
|
| 12 |
+
COPY . .
|
| 13 |
+
|
| 14 |
+
# Create a non-root user for security (Hugging Face Spaces often run as user 1000)
|
| 15 |
+
RUN useradd -m -u 1000 user
|
| 16 |
+
USER user
|
| 17 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 18 |
+
|
| 19 |
+
# Expose the port Hugging Face Spaces expects (7860)
|
| 20 |
+
EXPOSE 7860
|
| 21 |
+
|
| 22 |
+
# Command to run the application
|
| 23 |
+
# We use uvicorn to run the app
|
| 24 |
+
# host 0.0.0.0 is required for container networking
|
| 25 |
+
# port 7860 is required by Hugging Face Spaces
|
| 26 |
+
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
check_function_tool_type.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import agents
|
| 2 |
+
print(f"agents.function_tool is: {type(agents.function_tool)}")
|
| 3 |
+
print(f"agents.FunctionTool is: {type(agents.FunctionTool)}")
|
init_db.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import sys
|
| 3 |
+
from src.database import init_db
|
| 4 |
+
from src.models import User, Task, Session, Account, Verification, Conversation, Message # Ensure all models are registered
|
| 5 |
+
from src.database import engine
|
| 6 |
+
from sqlmodel import SQLModel
|
| 7 |
+
|
| 8 |
+
async def recreate_db():
|
| 9 |
+
print("Dropping existing tables...")
|
| 10 |
+
async with engine.begin() as conn:
|
| 11 |
+
await conn.run_sync(SQLModel.metadata.drop_all)
|
| 12 |
+
print("Initializing database...")
|
| 13 |
+
async with engine.begin() as conn:
|
| 14 |
+
await conn.run_sync(SQLModel.metadata.create_all)
|
| 15 |
+
print("Database initialized successfully with Better Auth tables.")
|
| 16 |
+
|
| 17 |
+
if __name__ == "__main__":
|
| 18 |
+
if sys.platform == 'win32':
|
| 19 |
+
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
| 20 |
+
asyncio.run(recreate_db())
|
pyproject.toml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "mytodo-backend"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "MyTodoApp Backend"
|
| 5 |
+
requires-python = ">=3.10"
|
| 6 |
+
|
| 7 |
+
[tool.ruff]
|
| 8 |
+
line-length = 88
|
| 9 |
+
target-version = "py310"
|
| 10 |
+
|
| 11 |
+
[tool.black]
|
| 12 |
+
line-length = 88
|
| 13 |
+
target-version = ['py310']
|
| 14 |
+
|
| 15 |
+
[tool.pytest.ini_options]
|
| 16 |
+
asyncio_mode = "auto"
|
| 17 |
+
testpaths = ["tests"]
|
| 18 |
+
python_files = ["test_*.py"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.115.0
|
| 2 |
+
uvicorn[standard]>=0.30.0
|
| 3 |
+
sqlmodel>=0.0.22
|
| 4 |
+
psycopg[binary]>=3.2.0
|
| 5 |
+
asyncpg>=0.29.0
|
| 6 |
+
pydantic-settings>=2.5.0
|
| 7 |
+
python-jose[cryptography]>=3.3.0
|
| 8 |
+
passlib[bcrypt]>=1.7.4
|
| 9 |
+
python-multipart>=0.0.12
|
| 10 |
+
pytest>=8.0.0
|
| 11 |
+
pytest-asyncio>=0.24.0
|
| 12 |
+
httpx>=0.27.0
|
| 13 |
+
openai-agents
|
| 14 |
+
mcp
|
| 15 |
+
openai
|
| 16 |
+
sqlalchemy>=2.0.0
|
| 17 |
+
greenlet
|
| 18 |
+
python-dotenv
|
src/agents/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Agents module
|
src/agents/chatbot.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import asyncio
|
| 3 |
+
from typing import List, Optional, Dict, Any
|
| 4 |
+
from openai import AsyncOpenAI
|
| 5 |
+
from agents import Agent, Runner, set_default_openai_client, function_tool, ModelSettings, set_tracing_disabled, OpenAIChatCompletionsModel
|
| 6 |
+
from src.config import settings
|
| 7 |
+
from src.agents.tools import (
|
| 8 |
+
create_task_tool,
|
| 9 |
+
list_tasks_tool,
|
| 10 |
+
update_task_tool,
|
| 11 |
+
delete_task_tool
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
# Configure Gemini via OpenAI SDK (OpenRouter)
|
| 15 |
+
GEMINI_API_KEY = settings.GEMINI_API_KEY
|
| 16 |
+
client = None
|
| 17 |
+
if GEMINI_API_KEY:
|
| 18 |
+
# Disable tracing to avoid 401 errors from OpenAI-specific telemetry
|
| 19 |
+
set_tracing_disabled(True)
|
| 20 |
+
|
| 21 |
+
# Set default client for the entire process (using OpenRouter)
|
| 22 |
+
client = AsyncOpenAI(
|
| 23 |
+
api_key=GEMINI_API_KEY,
|
| 24 |
+
base_url="https://openrouter.ai/api/v1",
|
| 25 |
+
default_headers={
|
| 26 |
+
"HTTP-Referer": "http://localhost:3000",
|
| 27 |
+
"X-Title": "MyTodoApp",
|
| 28 |
+
}
|
| 29 |
+
)
|
| 30 |
+
set_default_openai_client(client)
|
| 31 |
+
# Set OPENAI_API_KEY to satisfy internal SDK requirements
|
| 32 |
+
os.environ["OPENAI_API_KEY"] = GEMINI_API_KEY
|
| 33 |
+
|
| 34 |
+
class TodoChatbot:
|
| 35 |
+
def __init__(self, db, user_id: str):
|
| 36 |
+
self.db = db
|
| 37 |
+
self.user_id = user_id
|
| 38 |
+
# Lock to ensure sequential database access per chatbot instance
|
| 39 |
+
self.lock = asyncio.Lock()
|
| 40 |
+
|
| 41 |
+
# Define local functions that wrap the tools with session, user context, and Lock
|
| 42 |
+
async def create_task(title: str, description: Optional[str] = None, status: str = "TODO", priority: str = "MEDIUM") -> str:
|
| 43 |
+
"""
|
| 44 |
+
Create a new task for the user.
|
| 45 |
+
Args:
|
| 46 |
+
title: The title of the task.
|
| 47 |
+
description: A detailed description of the task.
|
| 48 |
+
status: The status of the task (TODO, IN_PROGRESS, DONE).
|
| 49 |
+
priority: The priority of the task (LOW, MEDIUM, HIGH).
|
| 50 |
+
"""
|
| 51 |
+
async with self.lock:
|
| 52 |
+
return await create_task_tool(self.db, self.user_id, title, description, status, priority)
|
| 53 |
+
|
| 54 |
+
async def list_tasks(status: Optional[str] = None) -> str:
|
| 55 |
+
"""
|
| 56 |
+
List tasks for the user.
|
| 57 |
+
Args:
|
| 58 |
+
status: Optional filter by status (TODO, IN_PROGRESS, DONE).
|
| 59 |
+
"""
|
| 60 |
+
async with self.lock:
|
| 61 |
+
return await list_tasks_tool(self.db, self.user_id, status)
|
| 62 |
+
|
| 63 |
+
async def update_task(task_id: str, title: Optional[str] = None, description: Optional[str] = None, status: Optional[str] = None, priority: Optional[str] = None) -> str:
|
| 64 |
+
"""
|
| 65 |
+
Update an existing task.
|
| 66 |
+
Args:
|
| 67 |
+
task_id: The UUID of the task to update.
|
| 68 |
+
title: New title for the task.
|
| 69 |
+
description: New description for the task.
|
| 70 |
+
status: New status for the task.
|
| 71 |
+
priority: New priority for the task.
|
| 72 |
+
"""
|
| 73 |
+
async with self.lock:
|
| 74 |
+
return await update_task_tool(self.db, self.user_id, task_id, title, description, status, priority)
|
| 75 |
+
|
| 76 |
+
async def delete_task(task_id: str) -> str:
|
| 77 |
+
"""
|
| 78 |
+
Delete a task.
|
| 79 |
+
Args:
|
| 80 |
+
task_id: The UUID of the task to delete.
|
| 81 |
+
"""
|
| 82 |
+
async with self.lock:
|
| 83 |
+
return await delete_task_tool(self.db, self.user_id, task_id)
|
| 84 |
+
|
| 85 |
+
# Wrap OpenRouter model to bypass SDK prefix validation
|
| 86 |
+
gemini_model = OpenAIChatCompletionsModel(
|
| 87 |
+
model="google/gemini-2.0-flash-001",
|
| 88 |
+
openai_client=client
|
| 89 |
+
) if client else "gpt-4o"
|
| 90 |
+
|
| 91 |
+
self.agent = Agent(
|
| 92 |
+
name="TodoBot",
|
| 93 |
+
model=gemini_model,
|
| 94 |
+
instructions=(
|
| 95 |
+
"You are an extremely concise Todo List assistant. "
|
| 96 |
+
"Actions: Create, List, Update (Mark Done), Delete. "
|
| 97 |
+
"Instructions: "
|
| 98 |
+
"1. BE BRIEF. No conversational filler like 'Sure, I can help'. "
|
| 99 |
+
"2. When asked to update/complete/delete a task, ALWAYS call 'list_tasks' first to find the ID. "
|
| 100 |
+
"3. To mark a task as DONE, use 'update_task' with status='DONE'. "
|
| 101 |
+
"4. Confirm actions with a single short sentence: 'Done: Task marked complete.' or 'Task created.' "
|
| 102 |
+
"5. Never ask follow-up questions unless the request is truly ambiguous."
|
| 103 |
+
),
|
| 104 |
+
tools=[
|
| 105 |
+
function_tool(create_task),
|
| 106 |
+
function_tool(list_tasks),
|
| 107 |
+
function_tool(update_task),
|
| 108 |
+
function_tool(delete_task)
|
| 109 |
+
],
|
| 110 |
+
model_settings=ModelSettings(temperature=0.1, max_tokens=1024, parallel_tool_calls=False)
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
async def get_response(self, messages: List[Dict[str, Any]]) -> Any:
|
| 114 |
+
"""
|
| 115 |
+
Process a list of messages and return the agent's response.
|
| 116 |
+
"""
|
| 117 |
+
runner = Runner()
|
| 118 |
+
result = await runner.run(self.agent, messages)
|
| 119 |
+
return result
|
src/agents/tools.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import uuid
|
| 3 |
+
from typing import List, Optional, Any
|
| 4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 5 |
+
from sqlalchemy.future import select
|
| 6 |
+
from src.models import Task, TaskStatus, TaskPriority
|
| 7 |
+
|
| 8 |
+
class ToolResult:
|
| 9 |
+
def __init__(self, success: bool, data: Any = None, error: str = None):
|
| 10 |
+
self.success = success
|
| 11 |
+
self.data = data
|
| 12 |
+
self.error = error
|
| 13 |
+
|
| 14 |
+
def to_json(self) -> str:
|
| 15 |
+
return json.dumps({
|
| 16 |
+
"success": self.success,
|
| 17 |
+
"data": self.data,
|
| 18 |
+
"error": self.error
|
| 19 |
+
})
|
| 20 |
+
|
| 21 |
+
# Base MCP Tool Wrapper
|
| 22 |
+
# In a real MCP setup, these would be registered with an MCP server.
|
| 23 |
+
# For this implementation, they will be used by the OpenAI Agents SDK.
|
| 24 |
+
|
| 25 |
+
def format_task(task: Task) -> dict:
|
| 26 |
+
return {
|
| 27 |
+
"id": str(task.id),
|
| 28 |
+
"title": task.title,
|
| 29 |
+
"description": task.description,
|
| 30 |
+
"status": task.status,
|
| 31 |
+
"priority": task.priority,
|
| 32 |
+
"created_at": task.created_at.isoformat()
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
async def create_task_tool(db: AsyncSession, user_id: str, title: str, description: Optional[str] = None, status: str = "TODO", priority: str = "MEDIUM") -> str:
|
| 36 |
+
|
| 37 |
+
"""Create a new task for the authenticated user."""
|
| 38 |
+
|
| 39 |
+
print(f"Tool [create_task]: user={user_id} title='{title}' status={status}")
|
| 40 |
+
|
| 41 |
+
# Explicitly convert strings to Enums
|
| 42 |
+
|
| 43 |
+
task_status = TaskStatus(status.upper()) if isinstance(status, str) else status
|
| 44 |
+
|
| 45 |
+
task_priority = TaskPriority(priority.upper()) if isinstance(priority, str) else priority
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
task = Task(title=title, description=description, status=task_status, priority=task_priority, owner_id=user_id)
|
| 50 |
+
|
| 51 |
+
db.add(task)
|
| 52 |
+
|
| 53 |
+
await db.commit()
|
| 54 |
+
|
| 55 |
+
await db.refresh(task)
|
| 56 |
+
|
| 57 |
+
return json.dumps(format_task(task))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
async def list_tasks_tool(db: AsyncSession, user_id: str, status: Optional[str] = None) -> str:
|
| 62 |
+
|
| 63 |
+
"""List all tasks for the authenticated user, optionally filtered by status."""
|
| 64 |
+
|
| 65 |
+
print(f"Tool [list_tasks]: user={user_id} status_filter={status}")
|
| 66 |
+
|
| 67 |
+
statement = select(Task).where(Task.owner_id == user_id)
|
| 68 |
+
|
| 69 |
+
if status:
|
| 70 |
+
|
| 71 |
+
task_status = TaskStatus(status.upper()) if isinstance(status, str) else status
|
| 72 |
+
|
| 73 |
+
statement = statement.where(Task.status == task_status)
|
| 74 |
+
|
| 75 |
+
result = await db.execute(statement)
|
| 76 |
+
|
| 77 |
+
tasks = result.scalars().all()
|
| 78 |
+
|
| 79 |
+
return json.dumps([format_task(t) for t in tasks])
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
async def update_task_tool(db: AsyncSession, user_id: str, task_id: str, title: Optional[str] = None, description: Optional[str] = None, status: Optional[str] = None, priority: Optional[str] = None) -> str:
|
| 84 |
+
|
| 85 |
+
"""Update an existing task for the authenticated user."""
|
| 86 |
+
|
| 87 |
+
print(f"Tool [update_task]: id={task_id} status_update={status}")
|
| 88 |
+
|
| 89 |
+
try:
|
| 90 |
+
|
| 91 |
+
task_uuid = uuid.UUID(task_id)
|
| 92 |
+
|
| 93 |
+
except ValueError:
|
| 94 |
+
|
| 95 |
+
return json.dumps({"error": "Invalid task ID format"})
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
statement = select(Task).where(Task.id == task_uuid, Task.owner_id == user_id)
|
| 100 |
+
|
| 101 |
+
result = await db.execute(statement)
|
| 102 |
+
|
| 103 |
+
task = result.scalars().first()
|
| 104 |
+
|
| 105 |
+
if not task:
|
| 106 |
+
|
| 107 |
+
print(f"Tool [update_task]: Task {task_id} NOT FOUND for user {user_id}")
|
| 108 |
+
|
| 109 |
+
return json.dumps({"error": "Task not found or access denied"})
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
if title is not None:
|
| 114 |
+
|
| 115 |
+
task.title = title
|
| 116 |
+
|
| 117 |
+
if description is not None:
|
| 118 |
+
|
| 119 |
+
task.description = description
|
| 120 |
+
|
| 121 |
+
if status is not None:
|
| 122 |
+
|
| 123 |
+
task.status = TaskStatus(status.upper()) if isinstance(status, str) else status
|
| 124 |
+
|
| 125 |
+
if priority is not None:
|
| 126 |
+
|
| 127 |
+
task.priority = TaskPriority(priority.upper()) if isinstance(priority, str) else priority
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
db.add(task)
|
| 132 |
+
|
| 133 |
+
await db.commit()
|
| 134 |
+
|
| 135 |
+
await db.refresh(task)
|
| 136 |
+
|
| 137 |
+
print(f"Tool [update_task]: Task {task_id} UPDATED successfully. New status: {task.status}")
|
| 138 |
+
|
| 139 |
+
return json.dumps(format_task(task))
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
async def delete_task_tool(db: AsyncSession, user_id: str, task_id: str) -> str:
|
| 144 |
+
|
| 145 |
+
"""Delete a task for the authenticated user."""
|
| 146 |
+
|
| 147 |
+
print(f"Tool [delete_task]: id={task_id}")
|
src/auth.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from fastapi import Depends, HTTPException, status
|
| 3 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 5 |
+
from sqlalchemy.future import select
|
| 6 |
+
from src.database import get_session
|
| 7 |
+
from src.models import User, Session
|
| 8 |
+
|
| 9 |
+
security = HTTPBearer()
|
| 10 |
+
|
| 11 |
+
async def get_current_user(auth: HTTPAuthorizationCredentials = Depends(security), session: AsyncSession = Depends(get_session)) -> User:
|
| 12 |
+
token = auth.credentials
|
| 13 |
+
|
| 14 |
+
# Query Better Auth session table
|
| 15 |
+
stmt = select(Session).where(Session.token == token)
|
| 16 |
+
result = await session.execute(stmt)
|
| 17 |
+
db_session = result.scalars().first()
|
| 18 |
+
|
| 19 |
+
if not db_session:
|
| 20 |
+
raise HTTPException(
|
| 21 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 22 |
+
detail="Invalid session token",
|
| 23 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
if db_session.expiresAt < datetime.utcnow():
|
| 27 |
+
raise HTTPException(
|
| 28 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 29 |
+
detail="Session expired",
|
| 30 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
user = await session.get(User, db_session.userId)
|
| 34 |
+
if user is None:
|
| 35 |
+
raise HTTPException(
|
| 36 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 37 |
+
detail="User not found",
|
| 38 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 39 |
+
)
|
| 40 |
+
return user
|
src/config.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_settings import BaseSettings
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
class Settings(BaseSettings):
|
| 5 |
+
DATABASE_URL: str = "postgresql+asyncpg://user:password@localhost/mytodo_db"
|
| 6 |
+
SECRET_KEY: str = "your-secret-key-must-be-changed-in-production"
|
| 7 |
+
ALGORITHM: str = "HS256"
|
| 8 |
+
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
| 9 |
+
GEMINI_API_KEY: Optional[str] = None
|
| 10 |
+
|
| 11 |
+
class Config:
|
| 12 |
+
env_file = ".env"
|
| 13 |
+
extra = "ignore"
|
| 14 |
+
|
| 15 |
+
settings = Settings()
|
src/database.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlmodel import SQLModel, create_engine
|
| 2 |
+
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
| 3 |
+
from sqlalchemy.orm import sessionmaker
|
| 4 |
+
from src.config import settings
|
| 5 |
+
|
| 6 |
+
DATABASE_URL = settings.DATABASE_URL
|
| 7 |
+
|
| 8 |
+
# Pass SSL mode explicitly for asyncpg
|
| 9 |
+
engine = create_async_engine(
|
| 10 |
+
DATABASE_URL,
|
| 11 |
+
echo=True,
|
| 12 |
+
future=True,
|
| 13 |
+
connect_args={"ssl": "require"}
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
async def get_session() -> AsyncSession:
|
| 17 |
+
async_session = sessionmaker(
|
| 18 |
+
engine, class_=AsyncSession, expire_on_commit=False
|
| 19 |
+
)
|
| 20 |
+
async with async_session() as session:
|
| 21 |
+
yield session
|
| 22 |
+
|
| 23 |
+
async def init_db():
|
| 24 |
+
async with engine.begin() as conn:
|
| 25 |
+
# await conn.run_sync(SQLModel.metadata.drop_all)
|
| 26 |
+
await conn.run_sync(SQLModel.metadata.create_all)
|
src/main.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import asyncio
|
| 3 |
+
from fastapi import FastAPI, Request
|
| 4 |
+
from fastapi.responses import JSONResponse
|
| 5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
+
from src.config import settings
|
| 7 |
+
from src.routes import tasks, chat
|
| 8 |
+
|
| 9 |
+
# Set event loop policy for Windows to support asyncpg/SQLAlchemy correctly
|
| 10 |
+
if os.name == 'nt':
|
| 11 |
+
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
| 12 |
+
|
| 13 |
+
app = FastAPI(title="MyTodoApp Backend", version="0.1.0")
|
| 14 |
+
|
| 15 |
+
# CORS
|
| 16 |
+
# In production, set BACKEND_CORS_ORIGINS="https://your-frontend.vercel.app"
|
| 17 |
+
origins = settings.BACKEND_CORS_ORIGINS.split(",") if hasattr(settings, "BACKEND_CORS_ORIGINS") and settings.BACKEND_CORS_ORIGINS else [
|
| 18 |
+
"http://localhost:3000",
|
| 19 |
+
"http://127.0.0.1:3000",
|
| 20 |
+
"*" # Allow all for hackathon/testing ease - REMOVE in strict production
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
app.add_middleware(
|
| 24 |
+
CORSMiddleware,
|
| 25 |
+
allow_origins=origins,
|
| 26 |
+
allow_credentials=True,
|
| 27 |
+
allow_methods=["*"],
|
| 28 |
+
allow_headers=["*"],
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
app.include_router(tasks.router)
|
| 32 |
+
app.include_router(chat.router)
|
| 33 |
+
|
| 34 |
+
@app.exception_handler(Exception)
|
| 35 |
+
async def global_exception_handler(request: Request, exc: Exception):
|
| 36 |
+
# Print exception for debugging
|
| 37 |
+
print(f"Global exception: {exc}")
|
| 38 |
+
return JSONResponse(
|
| 39 |
+
status_code=500,
|
| 40 |
+
content={"detail": str(exc)},
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
@app.get("/health")
|
| 44 |
+
async def health_check():
|
| 45 |
+
return {"status": "ok"}
|
src/models.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
from typing import Optional, List
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from sqlmodel import Field, SQLModel, Relationship
|
| 6 |
+
|
| 7 |
+
class TaskStatus(str, Enum):
|
| 8 |
+
TODO = "TODO"
|
| 9 |
+
IN_PROGRESS = "IN_PROGRESS"
|
| 10 |
+
DONE = "DONE"
|
| 11 |
+
|
| 12 |
+
class TaskPriority(str, Enum):
|
| 13 |
+
LOW = "LOW"
|
| 14 |
+
MEDIUM = "MEDIUM"
|
| 15 |
+
HIGH = "HIGH"
|
| 16 |
+
|
| 17 |
+
# Better Auth Compatible Models
|
| 18 |
+
class UserBase(SQLModel):
|
| 19 |
+
email: str = Field(unique=True, index=True)
|
| 20 |
+
name: Optional[str] = None
|
| 21 |
+
image: Optional[str] = None
|
| 22 |
+
|
| 23 |
+
class User(UserBase, table=True):
|
| 24 |
+
id: str = Field(primary_key=True)
|
| 25 |
+
emailVerified: bool = Field(default=False)
|
| 26 |
+
createdAt: datetime = Field(default_factory=datetime.utcnow)
|
| 27 |
+
updatedAt: datetime = Field(default_factory=datetime.utcnow)
|
| 28 |
+
|
| 29 |
+
tasks: List["Task"] = Relationship(back_populates="owner")
|
| 30 |
+
sessions: List["Session"] = Relationship(back_populates="user")
|
| 31 |
+
conversations: List["Conversation"] = Relationship(back_populates="user")
|
| 32 |
+
|
| 33 |
+
class UserCreate(UserBase):
|
| 34 |
+
id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()))
|
| 35 |
+
password: str
|
| 36 |
+
|
| 37 |
+
class UserRead(UserBase):
|
| 38 |
+
id: str
|
| 39 |
+
createdAt: datetime
|
| 40 |
+
updatedAt: datetime
|
| 41 |
+
|
| 42 |
+
class Session(SQLModel, table=True):
|
| 43 |
+
id: str = Field(primary_key=True)
|
| 44 |
+
expiresAt: datetime
|
| 45 |
+
token: str
|
| 46 |
+
createdAt: datetime = Field(default_factory=datetime.utcnow)
|
| 47 |
+
updatedAt: datetime = Field(default_factory=datetime.utcnow)
|
| 48 |
+
ipAddress: Optional[str] = None
|
| 49 |
+
userAgent: Optional[str] = None
|
| 50 |
+
userId: str = Field(foreign_key="user.id")
|
| 51 |
+
|
| 52 |
+
user: User = Relationship(back_populates="sessions")
|
| 53 |
+
|
| 54 |
+
class Account(SQLModel, table=True):
|
| 55 |
+
id: str = Field(primary_key=True)
|
| 56 |
+
accountId: str
|
| 57 |
+
providerId: str
|
| 58 |
+
userId: str = Field(foreign_key="user.id")
|
| 59 |
+
accessToken: Optional[str] = None
|
| 60 |
+
refreshToken: Optional[str] = None
|
| 61 |
+
idToken: Optional[str] = None
|
| 62 |
+
accessTokenExpiresAt: Optional[datetime] = None
|
| 63 |
+
refreshTokenExpiresAt: Optional[datetime] = None
|
| 64 |
+
scope: Optional[str] = None
|
| 65 |
+
password: Optional[str] = None
|
| 66 |
+
createdAt: datetime = Field(default_factory=datetime.utcnow)
|
| 67 |
+
updatedAt: datetime = Field(default_factory=datetime.utcnow)
|
| 68 |
+
|
| 69 |
+
class Verification(SQLModel, table=True):
|
| 70 |
+
id: str = Field(primary_key=True)
|
| 71 |
+
identifier: str
|
| 72 |
+
value: str
|
| 73 |
+
expiresAt: datetime
|
| 74 |
+
createdAt: datetime = Field(default_factory=datetime.utcnow)
|
| 75 |
+
updatedAt: datetime = Field(default_factory=datetime.utcnow)
|
| 76 |
+
|
| 77 |
+
class TaskBase(SQLModel):
|
| 78 |
+
title: str = Field(index=True)
|
| 79 |
+
description: Optional[str] = None
|
| 80 |
+
status: TaskStatus = Field(default=TaskStatus.TODO)
|
| 81 |
+
priority: Optional[TaskPriority] = Field(default=TaskPriority.MEDIUM)
|
| 82 |
+
|
| 83 |
+
class Task(TaskBase, table=True):
|
| 84 |
+
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
| 85 |
+
created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 86 |
+
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
| 87 |
+
owner_id: str = Field(foreign_key="user.id")
|
| 88 |
+
|
| 89 |
+
owner: User = Relationship(back_populates="tasks")
|
| 90 |
+
|
| 91 |
+
class TaskCreate(TaskBase):
|
| 92 |
+
pass
|
| 93 |
+
|
| 94 |
+
class TaskUpdate(SQLModel):
|
| 95 |
+
title: Optional[str] = None
|
| 96 |
+
description: Optional[str] = None
|
| 97 |
+
status: Optional[TaskStatus] = None
|
| 98 |
+
priority: Optional[TaskPriority] = None
|
| 99 |
+
|
| 100 |
+
class TaskRead(TaskBase):
|
| 101 |
+
id: uuid.UUID
|
| 102 |
+
created_at: datetime
|
| 103 |
+
updated_at: datetime
|
| 104 |
+
owner_id: str
|
| 105 |
+
|
| 106 |
+
class MessageRole(str, Enum):
|
| 107 |
+
USER = "user"
|
| 108 |
+
ASSISTANT = "assistant"
|
| 109 |
+
SYSTEM = "system"
|
| 110 |
+
TOOL = "tool"
|
| 111 |
+
|
| 112 |
+
class Conversation(SQLModel, table=True):
|
| 113 |
+
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
| 114 |
+
user_id: str = Field(foreign_key="user.id")
|
| 115 |
+
title: Optional[str] = None
|
| 116 |
+
created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 117 |
+
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
| 118 |
+
|
| 119 |
+
user: User = Relationship(back_populates="conversations")
|
| 120 |
+
messages: List["Message"] = Relationship(back_populates="conversation")
|
| 121 |
+
|
| 122 |
+
class Message(SQLModel, table=True):
|
| 123 |
+
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
| 124 |
+
conversation_id: uuid.UUID = Field(foreign_key="conversation.id")
|
| 125 |
+
role: MessageRole
|
| 126 |
+
content: Optional[str] = None
|
| 127 |
+
tool_calls: Optional[str] = None # Store as JSON string
|
| 128 |
+
tool_call_id: Optional[str] = None # For 'tool' role messages
|
| 129 |
+
name: Optional[str] = None # Name of the tool called
|
| 130 |
+
created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 131 |
+
|
| 132 |
+
conversation: Conversation = Relationship(back_populates="messages")
|
src/routes/chat.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 2 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 3 |
+
from sqlalchemy.future import select
|
| 4 |
+
from src.database import get_session
|
| 5 |
+
from src.auth import get_current_user
|
| 6 |
+
from src.models import User, Conversation, Message, MessageRole
|
| 7 |
+
from src.agents.chatbot import TodoChatbot
|
| 8 |
+
from agents import ToolCallItem
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
import uuid
|
| 11 |
+
from typing import Optional
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
|
| 14 |
+
router = APIRouter(prefix="/chat", tags=["chat"])
|
| 15 |
+
|
| 16 |
+
class ChatRequest(BaseModel):
|
| 17 |
+
message: str
|
| 18 |
+
conversation_id: Optional[str] = None
|
| 19 |
+
|
| 20 |
+
@router.post("/")
|
| 21 |
+
async def chat(
|
| 22 |
+
request: ChatRequest,
|
| 23 |
+
db: AsyncSession = Depends(get_session),
|
| 24 |
+
current_user: User = Depends(get_current_user)
|
| 25 |
+
):
|
| 26 |
+
# 1. Resolve Conversation
|
| 27 |
+
if request.conversation_id:
|
| 28 |
+
try:
|
| 29 |
+
convo_uuid = uuid.UUID(request.conversation_id)
|
| 30 |
+
except ValueError:
|
| 31 |
+
raise HTTPException(status_code=400, detail="Invalid conversation_id format")
|
| 32 |
+
|
| 33 |
+
statement = select(Conversation).where(Conversation.id == convo_uuid, Conversation.user_id == current_user.id)
|
| 34 |
+
result = await db.execute(statement)
|
| 35 |
+
conversation = result.scalars().first()
|
| 36 |
+
if not conversation:
|
| 37 |
+
raise HTTPException(status_code=404, detail="Conversation not found or access denied")
|
| 38 |
+
else:
|
| 39 |
+
# Create new conversation
|
| 40 |
+
conversation = Conversation(user_id=current_user.id)
|
| 41 |
+
db.add(conversation)
|
| 42 |
+
await db.commit()
|
| 43 |
+
await db.refresh(conversation)
|
| 44 |
+
|
| 45 |
+
# 2. Store User Message
|
| 46 |
+
user_msg = Message(
|
| 47 |
+
conversation_id=conversation.id,
|
| 48 |
+
role=MessageRole.USER,
|
| 49 |
+
content=request.message
|
| 50 |
+
)
|
| 51 |
+
db.add(user_msg)
|
| 52 |
+
await db.commit()
|
| 53 |
+
|
| 54 |
+
# 3. Load History (Last 15 messages for more context)
|
| 55 |
+
history_statement = select(Message).where(Message.conversation_id == conversation.id).order_by(Message.created_at.desc()).limit(16)
|
| 56 |
+
result = await db.execute(history_statement)
|
| 57 |
+
messages_objs = result.scalars().all()
|
| 58 |
+
messages_objs = list(messages_objs)
|
| 59 |
+
messages_objs.reverse()
|
| 60 |
+
|
| 61 |
+
agent_messages = []
|
| 62 |
+
for msg in messages_objs:
|
| 63 |
+
role_val = msg.role.value
|
| 64 |
+
m = {
|
| 65 |
+
"role": role_val,
|
| 66 |
+
"content": msg.content or ""
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
# SDK Requirement: 'name' is only for 'tool' (and sometimes 'system')
|
| 70 |
+
# 'assistant' messages with a 'name' field cause Unhandled item errors.
|
| 71 |
+
if role_val == "tool":
|
| 72 |
+
m["tool_call_id"] = msg.tool_call_id
|
| 73 |
+
m["name"] = msg.name
|
| 74 |
+
elif role_val == "assistant":
|
| 75 |
+
if msg.tool_calls:
|
| 76 |
+
import json
|
| 77 |
+
m["tool_calls"] = json.loads(msg.tool_calls)
|
| 78 |
+
# DO NOT add 'name' to assistant role
|
| 79 |
+
elif role_val == "system":
|
| 80 |
+
if msg.name:
|
| 81 |
+
m["name"] = msg.name
|
| 82 |
+
|
| 83 |
+
agent_messages.append(m)
|
| 84 |
+
|
| 85 |
+
# 4. Run Agent
|
| 86 |
+
bot = TodoChatbot(db, current_user.id)
|
| 87 |
+
try:
|
| 88 |
+
# Result is a RunResult object
|
| 89 |
+
result = await bot.get_response(agent_messages)
|
| 90 |
+
|
| 91 |
+
# 5. Identify and Store NEW messages
|
| 92 |
+
full_history = result.to_input_list()
|
| 93 |
+
new_messages = full_history[len(agent_messages):]
|
| 94 |
+
|
| 95 |
+
import json
|
| 96 |
+
for msg in new_messages:
|
| 97 |
+
try:
|
| 98 |
+
role_str = msg.get("role", "assistant").lower()
|
| 99 |
+
# Validate role
|
| 100 |
+
if role_str not in [r.value for r in MessageRole]:
|
| 101 |
+
role_str = "assistant"
|
| 102 |
+
|
| 103 |
+
content = msg.get("content")
|
| 104 |
+
if isinstance(content, list):
|
| 105 |
+
# Flatten text content
|
| 106 |
+
text_parts = [item.get("text", "") for item in content if isinstance(item, dict) and item.get("type") == "output_text"]
|
| 107 |
+
content = "".join(text_parts).strip()
|
| 108 |
+
elif content is None:
|
| 109 |
+
content = ""
|
| 110 |
+
|
| 111 |
+
new_msg = Message(
|
| 112 |
+
conversation_id=conversation.id,
|
| 113 |
+
role=MessageRole(role_str),
|
| 114 |
+
content=content,
|
| 115 |
+
tool_calls=json.dumps(msg.get("tool_calls")) if msg.get("tool_calls") else None,
|
| 116 |
+
tool_call_id=msg.get("tool_call_id"),
|
| 117 |
+
name=msg.get("name")
|
| 118 |
+
)
|
| 119 |
+
db.add(new_msg)
|
| 120 |
+
except Exception as msg_err:
|
| 121 |
+
print(f"Error processing message for DB: {msg_err}")
|
| 122 |
+
print(f"Message dict: {msg}")
|
| 123 |
+
|
| 124 |
+
await db.commit()
|
| 125 |
+
|
| 126 |
+
# Update conversation timestamp
|
| 127 |
+
conversation.updated_at = datetime.utcnow()
|
| 128 |
+
db.add(conversation)
|
| 129 |
+
await db.commit()
|
| 130 |
+
|
| 131 |
+
# Extract the final response text
|
| 132 |
+
response_text = result.final_output
|
| 133 |
+
if not response_text and len(new_messages) > 0:
|
| 134 |
+
last_msg = new_messages[-1]
|
| 135 |
+
response_text = last_msg.get("content")
|
| 136 |
+
if isinstance(response_text, list):
|
| 137 |
+
text_parts = [item.get("text", "") for item in response_text if isinstance(item, dict) and item.get("type") == "output_text"]
|
| 138 |
+
response_text = "".join(text_parts).strip()
|
| 139 |
+
|
| 140 |
+
# Extract tool calls for frontend indicator
|
| 141 |
+
frontend_tool_calls = []
|
| 142 |
+
if hasattr(result, "new_items"):
|
| 143 |
+
for item in result.new_items:
|
| 144 |
+
if isinstance(item, ToolCallItem):
|
| 145 |
+
tool_name = "unknown"
|
| 146 |
+
if hasattr(item.raw_item, "name"):
|
| 147 |
+
tool_name = item.raw_item.name
|
| 148 |
+
elif hasattr(item.raw_item, "function") and hasattr(item.raw_item.function, "name"):
|
| 149 |
+
tool_name = item.raw_item.function.name
|
| 150 |
+
|
| 151 |
+
frontend_tool_calls.append({"function": {"name": tool_name}})
|
| 152 |
+
|
| 153 |
+
return {
|
| 154 |
+
"conversation_id": str(conversation.id),
|
| 155 |
+
"response": response_text or "I processed your request.",
|
| 156 |
+
"tool_calls": frontend_tool_calls
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
except Exception as e:
|
| 160 |
+
await db.rollback() # Crucial: Reset session state on error
|
| 161 |
+
import traceback
|
| 162 |
+
print(f"Chat Route Error: {e}")
|
| 163 |
+
traceback.print_exc()
|
| 164 |
+
raise HTTPException(status_code=500, detail=f"AI Agent failed: {str(e)}")
|
src/routes/tasks.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
import uuid
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from fastapi import APIRouter, Depends, HTTPException, status
|
| 5 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 6 |
+
from sqlalchemy.future import select
|
| 7 |
+
from src.database import get_session
|
| 8 |
+
from src.models import Task, TaskCreate, TaskRead, TaskUpdate, User, TaskStatus
|
| 9 |
+
from src.auth import get_current_user
|
| 10 |
+
|
| 11 |
+
router = APIRouter(prefix="/tasks", tags=["tasks"])
|
| 12 |
+
|
| 13 |
+
@router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED)
|
| 14 |
+
async def create_task(
|
| 15 |
+
task: TaskCreate,
|
| 16 |
+
current_user: User = Depends(get_current_user),
|
| 17 |
+
session: AsyncSession = Depends(get_session)
|
| 18 |
+
):
|
| 19 |
+
print(f"Creating task for user {current_user.id}: {task}")
|
| 20 |
+
try:
|
| 21 |
+
new_task = Task(**task.model_dump(), owner_id=current_user.id)
|
| 22 |
+
session.add(new_task)
|
| 23 |
+
await session.commit()
|
| 24 |
+
await session.refresh(new_task)
|
| 25 |
+
print(f"Task created successfully: {new_task.id}")
|
| 26 |
+
return new_task
|
| 27 |
+
except Exception as e:
|
| 28 |
+
print(f"Error creating task: {e}")
|
| 29 |
+
await session.rollback()
|
| 30 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 31 |
+
|
| 32 |
+
@router.get("/", response_model=List[TaskRead])
|
| 33 |
+
async def read_tasks(
|
| 34 |
+
current_user: User = Depends(get_current_user),
|
| 35 |
+
session: AsyncSession = Depends(get_session)
|
| 36 |
+
):
|
| 37 |
+
stmt = select(Task).where(Task.owner_id == current_user.id).order_by(Task.created_at.desc())
|
| 38 |
+
result = await session.execute(stmt)
|
| 39 |
+
tasks = result.scalars().all()
|
| 40 |
+
return tasks
|
| 41 |
+
|
| 42 |
+
@router.get("/{task_id}", response_model=TaskRead)
|
| 43 |
+
async def read_task(
|
| 44 |
+
task_id: uuid.UUID,
|
| 45 |
+
current_user: User = Depends(get_current_user),
|
| 46 |
+
session: AsyncSession = Depends(get_session)
|
| 47 |
+
):
|
| 48 |
+
stmt = select(Task).where(Task.id == task_id, Task.owner_id == current_user.id)
|
| 49 |
+
result = await session.execute(stmt)
|
| 50 |
+
task = result.scalars().first()
|
| 51 |
+
if not task:
|
| 52 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
| 53 |
+
return task
|
| 54 |
+
|
| 55 |
+
@router.put("/{task_id}", response_model=TaskRead)
|
| 56 |
+
async def update_task(
|
| 57 |
+
task_id: uuid.UUID,
|
| 58 |
+
task_update: TaskUpdate,
|
| 59 |
+
current_user: User = Depends(get_current_user),
|
| 60 |
+
session: AsyncSession = Depends(get_session)
|
| 61 |
+
):
|
| 62 |
+
stmt = select(Task).where(Task.id == task_id, Task.owner_id == current_user.id)
|
| 63 |
+
result = await session.execute(stmt)
|
| 64 |
+
task = result.scalars().first()
|
| 65 |
+
if not task:
|
| 66 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
| 67 |
+
|
| 68 |
+
task_data = task_update.model_dump(exclude_unset=True)
|
| 69 |
+
for key, value in task_data.items():
|
| 70 |
+
setattr(task, key, value)
|
| 71 |
+
|
| 72 |
+
task.updated_at = datetime.utcnow()
|
| 73 |
+
session.add(task)
|
| 74 |
+
await session.commit()
|
| 75 |
+
await session.refresh(task)
|
| 76 |
+
return task
|
| 77 |
+
|
| 78 |
+
@router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
|
| 79 |
+
async def delete_task(
|
| 80 |
+
task_id: uuid.UUID,
|
| 81 |
+
current_user: User = Depends(get_current_user),
|
| 82 |
+
session: AsyncSession = Depends(get_session)
|
| 83 |
+
):
|
| 84 |
+
stmt = select(Task).where(Task.id == task_id, Task.owner_id == current_user.id)
|
| 85 |
+
result = await session.execute(stmt)
|
| 86 |
+
task = result.scalars().first()
|
| 87 |
+
if not task:
|
| 88 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
| 89 |
+
|
| 90 |
+
await session.delete(task)
|
| 91 |
+
await session.commit()
|
| 92 |
+
return None
|
| 93 |
+
|
| 94 |
+
@router.patch("/{task_id}/complete", response_model=TaskRead)
|
| 95 |
+
async def toggle_task_completion(
|
| 96 |
+
task_id: uuid.UUID,
|
| 97 |
+
current_user: User = Depends(get_current_user),
|
| 98 |
+
session: AsyncSession = Depends(get_session)
|
| 99 |
+
):
|
| 100 |
+
stmt = select(Task).where(Task.id == task_id, Task.owner_id == current_user.id)
|
| 101 |
+
result = await session.execute(stmt)
|
| 102 |
+
task = result.scalars().first()
|
| 103 |
+
if not task:
|
| 104 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
| 105 |
+
|
| 106 |
+
if task.status == TaskStatus.DONE:
|
| 107 |
+
task.status = TaskStatus.TODO
|
| 108 |
+
else:
|
| 109 |
+
task.status = TaskStatus.DONE
|
| 110 |
+
|
| 111 |
+
task.updated_at = datetime.utcnow()
|
| 112 |
+
session.add(task)
|
| 113 |
+
await session.commit()
|
| 114 |
+
await session.refresh(task)
|
| 115 |
+
return task
|
tests/integration/test_chat.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from httpx import AsyncClient
|
| 3 |
+
from src.main import app
|
| 4 |
+
from src.models import User
|
| 5 |
+
|
| 6 |
+
@pytest.mark.asyncio
|
| 7 |
+
async def test_chat_unauthorized():
|
| 8 |
+
async with AsyncClient(app=app, base_url="http://test") as ac:
|
| 9 |
+
response = await ac.post("/chat/", json={"message": "hello"})
|
| 10 |
+
assert response.status_code == 401 # Should fail without JWT
|
| 11 |
+
|
| 12 |
+
# More tests would involve mocking get_current_user and the AI runner
|