File size: 2,297 Bytes
a291087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Optional
import logging

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlmodel import Session

from .dependencies import CurrentUserDep, get_session
from ..models.conversation import Conversation
from ..services.chat_service import ChatService
from ..core.config import settings

# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)


# Router for the chat endpoints: every route below is mounted under "/api" and tagged "chat".
router = APIRouter(prefix="/api", tags=["chat"])


class ChatRequest(BaseModel):
    """Request body for the chat endpoint.

    Only ``message`` is required; the endpoint itself validates
    ``conversation_id`` and ``model`` further.
    """

    # Required user message, 1 to 1000 characters long.
    message: str = Field(..., min_length=1, max_length=1000)
    # Existing conversation to continue; None leaves the choice to the chat service.
    conversation_id: Optional[int] = Field(default=None)
    # Requested model name; None means the endpoint falls back to its configured default.
    model: Optional[str] = Field(
        default=None,
        description="Gemini model to use",
    )


@router.post("/{user_id}/chat")
async def chat(
    user_id: int,
    body: ChatRequest,
    current_user: CurrentUserDep,
    session: Session = Depends(get_session),
):
    """Send a chat message for ``user_id`` and return the ChatService result.

    The path ``user_id`` must equal ``current_user.id``; the message is always
    sent to the service as ``current_user.id``.

    Args:
        user_id: User id from the URL path; must match ``current_user.id``.
        body: Message text, optional existing conversation id, optional model.
        current_user: User resolved by the ``CurrentUserDep`` dependency.
        session: Database session from ``get_session``; also handed to
            ``ChatService``.

    Returns:
        Whatever ``ChatService.chat`` returns, passed through unchanged.

    Raises:
        HTTPException: 403 if ``user_id`` does not match the current user;
            400 if ``conversation_id`` is not positive or the resolved model is
            not in ``settings.ALLOWED_GEMINI_MODELS``; 404 if the conversation
            does not exist or belongs to another user; 500 with a generic
            detail for any other error (the real error is only logged).
            An ``HTTPException`` raised inside the handler or by
            ``ChatService`` propagates unchanged.
    """
    try:
        if user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Access denied")

        if body.conversation_id is not None:
            if body.conversation_id <= 0:
                raise HTTPException(
                    status_code=400,
                    detail="conversation_id must be a positive integer",
                )
            # NOTE(review): session.get is a synchronous call inside an async
            # handler and may block the event loop — confirm this is acceptable.
            convo = session.get(Conversation, body.conversation_id)
            # Another user's conversation is reported as "not found" as well,
            # so the response does not reveal which conversation ids exist.
            if convo is None or convo.user_id != current_user.id:
                raise HTTPException(
                    status_code=404,
                    detail=f"Conversation {body.conversation_id} not found",
                )

        # Resolve the model (explicit choice, else the configured default) and
        # check it against the allow-list; the default is validated too.
        model = body.model or settings.GEMINI_DEFAULT_MODEL
        if model not in settings.ALLOWED_GEMINI_MODELS:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Allowed models: {', '.join(settings.ALLOWED_GEMINI_MODELS)}",
            )

        service = ChatService(session)
        return await service.chat(
            user_id=current_user.id,
            message=body.message,
            conversation_id=body.conversation_id,
            model=model,
        )
    except HTTPException:
        # Deliberate HTTP errors (raised above or by the service) pass through.
        raise
    except Exception as exc:
        # Keep internal details server-side: the full traceback goes to the
        # log, while the client only receives a generic message. Echoing
        # str(exc) would leak DB/upstream error text to API callers.
        logger.exception("Chat endpoint error: %s", exc)
        raise HTTPException(
            status_code=500,
            detail="Internal server error",
        ) from exc