Spaces:
Sleeping
Sleeping
File size: 2,297 Bytes
a291087 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | from typing import Optional
import logging
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlmodel import Session
from .dependencies import CurrentUserDep, get_session
from ..models.conversation import Conversation
from ..services.chat_service import ChatService
from ..core.config import settings
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router for the chat endpoint(s) declared in this file; routes registered on
# it are mounted under the "/api" prefix and carry the "chat" tag.
router = APIRouter(prefix="/api", tags=["chat"])
class ChatRequest(BaseModel):
    # Request body accepted by the chat endpoint.

    # Text sent by the user: required, between 1 and 1000 characters.
    message: str = Field(..., min_length=1, max_length=1000)

    # Existing conversation to continue; left as None when the client
    # does not supply one.
    conversation_id: Optional[int] = Field(default=None)

    # Model override; None means the client did not choose a model.
    model: Optional[str] = Field(None, description="Gemini model to use")
@router.post("/{user_id}/chat")
async def chat(
    user_id: int,
    body: ChatRequest,
    current_user: CurrentUserDep,
    session: Session = Depends(get_session),
):
    """Send a chat message on behalf of the authenticated user.

    The ``user_id`` in the path must match the authenticated user. When a
    ``conversation_id`` is supplied it must be a positive integer that
    refers to a conversation owned by that user. The model falls back to
    ``settings.GEMINI_DEFAULT_MODEL`` when the request leaves it unset or
    empty, and must be listed in ``settings.ALLOWED_GEMINI_MODELS``.

    Returns:
        The value returned by ``ChatService.chat`` for this message.

    Raises:
        HTTPException: 403 when ``user_id`` is not the authenticated user;
            400 when ``conversation_id`` is not positive or the model is
            not allowed; 404 when the conversation does not exist or is
            owned by another user; 500 for any other failure.
    """
    try:
        if user_id != current_user.id:
            raise HTTPException(status_code=403, detail="Access denied")

        conversation_id = body.conversation_id
        if conversation_id is not None:
            if conversation_id <= 0:
                raise HTTPException(
                    status_code=400,
                    detail="conversation_id must be a positive integer",
                )
            # NOTE(review): this is a direct call on the injected session
            # inside a coroutine; if the session is synchronous it blocks the
            # event loop for the duration of the lookup - confirm.
            convo = session.get(Conversation, conversation_id)
            # A conversation owned by someone else is reported exactly like
            # a missing one (same 404, same message).
            if convo is None or convo.user_id != current_user.id:
                raise HTTPException(
                    status_code=404,
                    detail=f"Conversation {conversation_id} not found",
                )

        # Validate the model, using the configured default when none is given.
        model = body.model or settings.GEMINI_DEFAULT_MODEL
        if model not in settings.ALLOWED_GEMINI_MODELS:
            allowed = ", ".join(settings.ALLOWED_GEMINI_MODELS)
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Allowed models: {allowed}",
            )

        service = ChatService(session)
        return await service.chat(
            user_id=current_user.id,
            message=body.message,
            conversation_id=conversation_id,
            model=model,
        )
    except HTTPException:
        # Deliberate client-facing errors raised above pass through untouched.
        raise
    except Exception as exc:
        # Record the full error and traceback server-side only; the response
        # carries a generic message so internal details (database errors,
        # upstream API messages) are not disclosed to the client.
        logger.exception("Chat endpoint error: %s", exc)
        raise HTTPException(
            status_code=500,
            detail="Internal server error",
        ) from exc
|