import asyncio
import os
import sys
import time
import uuid
import zoneinfo
from datetime import datetime, timezone, timedelta, date
from typing import Optional, List, Dict, Any, Union

import psycopg2
import psycopg2.extras
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.responses import PlainTextResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from psycopg2 import pool as psycopg2_pool
from pydantic import BaseModel, Field

from src.recommendation_api import (
    TourRecommendationRequest,
    TourRecommendationResponse,
    get_tour_recommendations,
)
|
|
# Populate os.environ from a local .env file (if one exists) before the src.*
# imports below run; presumably src.config reads its settings from the
# environment at import time — TODO confirm.
load_dotenv()
|
|
# Project modules are imported defensively so the process can still start when
# the src package (or one of its dependencies) is missing or misconfigured.
# Every name falls back to a placeholder; endpoints check for None where they
# can and answer 503 instead of crashing.
try:
    from src.config import (
        DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME,
        DB_ENDPOINT_ID, GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM,
    )
    from src.database import conn_pool
    from src.graph_builder import graph_app
    from src.embedding import embedding_model
    from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
    from src.tools import extract_entities_tool, search_tours_tool
except ImportError as e:
    print(f"Error importing from src: {e}. Using placeholders. API will likely fail at runtime until this is fixed.")
    (DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME,
     DB_ENDPOINT_ID, GOOGLE_API_KEY, JWT_SECRET_KEY, ALGORITHM) = [None] * 9
    conn_pool = None
    graph_app = None
    embedding_model = None
    # These two were previously left undefined here, which turned
    # /api/tours/search into a NameError instead of a clean "not initialized".
    extract_entities_tool = None
    search_tours_tool = None

    class HumanMessage:
        """Minimal stand-in for langchain_core's HumanMessage (content only)."""

        def __init__(self, content):
            self.content = content

    class AIMessage:
        """Minimal stand-in for langchain_core's AIMessage (content only)."""

        def __init__(self, content):
            self.content = content

    # Only used in type annotations when langchain_core is unavailable.
    BaseMessage = Dict
|
|
# The ASGI application; title and version are surfaced in the generated OpenAPI docs.
app = FastAPI(title="Travel Chatbot, Tour Search & Recommendation API", version="1.0.0")

# Extracts the "Authorization: Bearer <token>" credentials consumed by
# get_current_user; scheme_name is the label shown in the OpenAPI security schema.
reusable_oauth2 = HTTPBearer(scheme_name="Bearer")
|
|
class EmbeddingRequest(BaseModel):
    # Request body for POST /api/embed: a single string or a batch of strings.
    # Docs are kept as comments rather than a docstring because pydantic copies
    # a class docstring into the OpenAPI schema description.
    text: Union[str, List[str]]
|
|
class EmbeddingResponse(BaseModel):
    # Response body for POST /api/embed.
    embeddings: List[List[float]]  # presumably one vector per input string — confirm with src.embedding
    model: str  # embedding_model.model_name at request time
    dimensions: int  # length of the first returned vector
|
|
class TourSearchRequest(BaseModel):
    # Request body for POST /api/tours/search. The free-text query is turned
    # into search entities server-side; page/limit drive 1-based pagination and
    # are bounds-checked by pydantic (page >= 1, 1 <= limit <= 100).
    query: str = Field(..., description="The search query for tours, e.g., 'tôi muốn đi đà nẵng'")
    page: int = Field(1, ge=1, description="Current page number for pagination")
    limit: int = Field(10, ge=1, le=100, description="Number of items per page")
|
|
class TourSummary(BaseModel):
    # One tour entry returned by POST /api/tours/search. search_tours_api fills
    # it from the rows produced by search_tours_tool, converting values to the
    # types declared here. Judging by the field names, a row combines a tour
    # with one departure and an optional promotion — TODO confirm.
    tour_id: Any
    title: str
    duration: Optional[str] = None
    departure_location: Optional[str] = None
    destination: Optional[List[str]] = None  # always a list of names when present
    region: Optional[str] = None
    itinerary: Optional[str] = None
    max_participants: Optional[int] = None
    departure_id: Optional[Any] = None
    start_date: Optional[Union[datetime, date]] = None
    price_adult: Optional[float] = None
    # Child prices, presumably keyed by height band in cm — TODO confirm.
    price_child_120_140: Optional[float] = None
    price_child_100_120: Optional[float] = None
    promotion_name: Optional[str] = None
    promotion_type: Optional[str] = None
    promotion_discount: Optional[Any] = None  # unit/meaning presumably depends on promotion_type — unverified
|
|
class PaginatedTourResponse(BaseModel):
    # Response body for POST /api/tours/search. Field names are camelCase on
    # the wire, presumably to match a frontend contract — confirm before renaming.
    currentPage: int  # the 1-based page that was requested
    itemsPerPage: int  # the requested page size (limit)
    totalItems: int  # number of matching tours across all pages
    totalPages: int  # ceil(totalItems / itemsPerPage); 0 when nothing matched
    hasNextPage: bool
    hasPrevPage: bool
    tours: List[TourSummary]  # only the tours on this page
|
|
class TokenData(BaseModel):
    # Decoded-token model. NOTE(review): not referenced in the visible code
    # (get_current_user returns a bare int), so this may be dead code.
    user_id: Optional[int] = None
|
|
async def get_current_user(token: HTTPAuthorizationCredentials = Depends(reusable_oauth2)) -> int:
    """Resolve the authenticated user's id from a Bearer JWT.

    The token is verified with ``JWT_SECRET_KEY`` / ``ALGORITHM``. The user id
    is read from the ``id`` claim, falling back to ``userId``.

    Args:
        token: Credentials extracted by ``reusable_oauth2`` from the
            ``Authorization: Bearer <token>`` header.

    Returns:
        The user id as an ``int``.

    Raises:
        HTTPException: 401 if the token cannot be decoded or verified, carries
            neither claim, or carries an id that is not an integer.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode step sits in the try. Previously the "missing claim"
    # raise was also inside it, so the broad handler caught and mis-logged our
    # own 401 as an "unexpected error".
    try:
        payload = jwt.decode(token.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as e:
        print(f"JWTError: {e}")
        raise credentials_exception from e
    except Exception as e:
        # e.g. JWT_SECRET_KEY / ALGORITHM left as None by the import fallback.
        print(f"An unexpected error occurred during JWT decoding: {e}")
        raise credentials_exception from e

    raw_user_id = payload.get("id")
    if raw_user_id is None:
        raw_user_id = payload.get("userId")
    if raw_user_id is None:
        # Log claim names only: the payload values may contain personal data.
        print(f"JWT payload does not contain 'id' or 'userId' field. Payload keys: {sorted(payload)}")
        raise credentials_exception

    # Coerce here so a non-numeric id fails authentication (401) instead of
    # failing response validation later (500) after history was written.
    try:
        return int(raw_user_id)
    except (TypeError, ValueError):
        print(f"JWT user id is not an integer: {raw_user_id!r}")
        raise credentials_exception from None
|
|
class ChatMessageInput(BaseModel):
    # Request body for POST /api/chat/. When session_id is omitted or empty,
    # chat_endpoint starts a new session with a generated UUID.
    message: str = Field(..., description="The text message sent by the user to the chatbot.")
    session_id: Optional[str] = Field(None, description="An optional identifier for a specific chat session.")
|
|
class ChatResponseOutput(BaseModel):
    # Response body for POST /api/chat/. The descriptions below feed the
    # OpenAPI docs. They previously claimed a UTC timestamp and a mirror-only
    # session_id, but chat_endpoint stamps Asia/Bangkok time and generates a
    # session id when none is sent.
    user_id: int = Field(..., description="The ID of the user (from JWT token).")
    response: str = Field(..., description="The chatbot's generated textual response.")
    session_id: Optional[str] = Field(None, description="The session identifier: mirrored from the input, or newly generated when none was provided.")
    timestamp: datetime = Field(..., description="Timezone-aware timestamp (Asia/Bangkok, UTC+07:00) of when the response was generated.")
|
|
def get_db_connection():
    """FastAPI dependency that lends one pooled database connection to a request.

    Yields:
        A connection taken from ``conn_pool``. It is always handed back with
        ``putconn`` once the request is finished, even if the handler raised.

    Raises:
        HTTPException: 503 if the pool was never initialised or no connection
            could be obtained from it.
    """
    if conn_pool is None:
        print("conn_pool is None in get_db_connection. Database module likely not initialized.")
        raise HTTPException(status_code=503, detail="Database connection pool not initialized. Check src.database and .env configuration.")

    # Acquire outside the yield's try/finally. Previously a failing getconn()
    # reached `finally: if conn:` with `conn` unbound and raised
    # UnboundLocalError. If nothing was acquired there is nothing to return.
    try:
        conn = conn_pool.getconn()
    except psycopg2_pool.PoolError as e:
        # Raised by psycopg2 pools when exhausted or closed (assumes conn_pool
        # is a psycopg2 pool — other errors propagate unchanged).
        print(f"Could not obtain a database connection from the pool: {e}")
        raise HTTPException(status_code=503, detail="Database connection unavailable. Please retry later.") from e

    try:
        yield conn
    finally:
        conn_pool.putconn(conn)
|
|
def fetch_conversation_history(db_conn, user_id: int, session_id: Optional[str] = None) -> List[BaseMessage]:
    """Load a user's stored chat turns as message objects, oldest first.

    Each ChatbotHistory row contributes a HumanMessage for its ``message``
    column and an AIMessage for its ``response`` column. Empty or NULL halves
    are skipped.

    Args:
        db_conn: An open psycopg2 connection.
        user_id: Owner of the history.
        session_id: Restrict to this session. When falsy, only rows whose
            session_id IS NULL are returned.

    Returns:
        The messages in interaction_time order. Best-effort: on any error the
        problem is printed, the transaction is rolled back, and whatever was
        collected so far (usually nothing) is returned.
    """
    history: List[BaseMessage] = []
    if session_id:
        sql = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id = %s ORDER BY interaction_time ASC"
        params = (user_id, session_id)
    else:
        sql = "SELECT message, response FROM ChatbotHistory WHERE user_id = %s AND session_id IS NULL ORDER BY interaction_time ASC"
        params = (user_id,)

    try:
        # psycopg2.extras must be imported explicitly at module level:
        # `import psycopg2` alone does not load that submodule, and the
        # resulting AttributeError used to be swallowed below, silently
        # returning an empty history.
        with db_conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
            cursor.execute(sql, params)
            records = cursor.fetchall()
        for record in records:
            if record["message"]:
                history.append(HumanMessage(content=record["message"]))
            if record["response"]:
                history.append(AIMessage(content=record["response"]))
    except Exception as e:
        print(f"Error fetching conversation history for user_id {user_id}, session_id {session_id}: {e}")
        # A failed statement leaves the transaction aborted. Without a rollback,
        # the caller's next statement on this connection (saving the
        # interaction) would be rejected as well.
        try:
            db_conn.rollback()
        except Exception as rollback_error:
            print(f"Rollback after failed history fetch also failed: {rollback_error}")
    return history
|
|
def save_interaction_to_history(db_conn, user_id: int, user_message: str, chatbot_response: str, session_id: Optional[str] = None):
    """Persist one user/chatbot exchange to ChatbotHistory and commit it.

    ``interaction_time`` is stamped with the current Asia/Bangkok time.

    Best-effort: failures are printed and the transaction is rolled back, and
    nothing is raised. A storage problem therefore never turns an
    already-generated reply into an error response.

    Args:
        db_conn: An open DB-API connection (psycopg2 in this service).
        user_id: Owner of the exchange.
        user_message: The text the user sent.
        chatbot_response: The reply that was produced.
        session_id: Session the exchange belongs to, or None.
    """
    try:
        with db_conn.cursor() as cursor:
            cursor.execute(
                "INSERT INTO ChatbotHistory (user_id, message, response, interaction_time, session_id) VALUES (%s, %s, %s, %s, %s)",
                (user_id, user_message, chatbot_response, datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok")), session_id),
            )
        db_conn.commit()
    except Exception as e:
        print(f"Error saving interaction to history for user_id {user_id}, session_id {session_id}: {e}")
        # rollback() itself raises when the connection is gone. Guard it so the
        # best-effort contract holds instead of escaping as a 500.
        try:
            db_conn.rollback()
        except Exception as rollback_error:
            print(f"Rollback after failed history save also failed: {rollback_error}")
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Warm up the shared embedding model once when the server starts.

    Best-effort: when no model object is available, nothing happens. A load
    failure is printed, and the server keeps starting.
    """
    try:
        if not embedding_model:
            return
        embedding_model.load_model()
    except Exception as load_error:
        print(f"Failed to load embedding model on startup: {str(load_error)}")
|
|
@app.get("/", include_in_schema=False, response_class=PlainTextResponse)
async def root():
    """Answer with a plain-text marker showing the API process is serving requests."""
    status_text = "API is running"
    return status_text
|
|
def _extract_chat_reply(result) -> str:
    """Return the reply text carried by a ``graph_app.invoke`` result, or "".

    A non-empty ``final_response`` wins. Otherwise the content of the trailing
    AIMessage in ``messages`` is used. The value is tested rather than the key:
    the request state always seeds ``final_response`` with None, so the key is
    normally present, and a key-presence test made the messages fallback
    unreachable.
    """
    if not isinstance(result, dict):
        return ""
    final_response = result.get("final_response")
    if final_response:
        return final_response
    messages = result.get("messages")
    if messages and isinstance(messages[-1], AIMessage):
        return messages[-1].content
    return ""


@app.post("/api/chat/", response_model=ChatResponseOutput, tags=["Chat"], summary="Chat with the Travel Chatbot")
async def chat_endpoint(payload: ChatMessageInput, current_user_id: int = Depends(get_current_user), db_conn = Depends(get_db_connection)):
    """Answer one chat message from the authenticated user.

    The endpoint:

    1. Loads the session's earlier turns.
    2. Runs them, plus the new message, through the chatbot graph.
    3. Stores the exchange and returns the reply.

    When no session_id is sent, a new one is generated and returned so the
    client can continue the same conversation.
    """
    if graph_app is None:
        print("graph_app is None in chat_endpoint. Graph_builder module likely not initialized.")
        raise HTTPException(status_code=503, detail="Chatbot graph not initialized. Check src.graph_builder.")

    user_id = current_user_id
    user_message_content = payload.message
    # Every turn is stored under some session: mint one when none was sent.
    session_id = payload.session_id or str(uuid.uuid4())

    # The history helpers and graph_app.invoke are synchronous (psycopg2 and
    # LLM calls). They run in worker threads so one slow request does not
    # stall the event loop for every other client.
    history = await asyncio.to_thread(fetch_conversation_history, db_conn, user_id, session_id)
    all_messages = history + [HumanMessage(content=user_message_content)]

    # Initial graph state: only the conversation and the raw query are known
    # up front; the remaining keys are filled in by the graph's nodes.
    inputs = {
        "messages": all_messages,
        "user_query": user_message_content,
        "current_date": None,
        "available_locations": None,
        "extracted_entities": None,
        "search_results": None,
        "final_response": None,
        "error": None,
        "routing_decision": None,
    }

    try:
        result = await asyncio.to_thread(graph_app.invoke, inputs)
        full_response_content = _extract_chat_reply(result)
    except Exception as e:
        print(f"Error during graph invocation for user_id {user_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing message with chatbot: {str(e)}") from e

    if not full_response_content:
        full_response_content = "Sorry, I could not process your request at this moment."

    await asyncio.to_thread(
        save_interaction_to_history, db_conn, user_id, user_message_content, full_response_content, session_id
    )

    return ChatResponseOutput(
        user_id=user_id,
        response=full_response_content,
        session_id=session_id,
        timestamp=datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok")),
    )
|
|
def _empty_tour_page(page: int, limit: int) -> PaginatedTourResponse:
    """Build the response used when extraction failed or nothing matched."""
    return PaginatedTourResponse(
        currentPage=page,
        itemsPerPage=limit,
        totalItems=0,
        totalPages=0,
        hasNextPage=False,
        hasPrevPage=False,
        tours=[],
    )


def _optional(value, cast):
    """Apply *cast* to *value*, passing None through unchanged."""
    return None if value is None else cast(value)


def _normalize_destination(raw) -> Optional[List[str]]:
    """Coerce a destination value to a list of strings (None/"" -> None).

    Lists and tuples are converted element-wise; any other scalar becomes a
    one-item list. Previously an empty string slipped through unconverted and
    failed TourSummary validation with a 500.
    """
    if raw is None or raw == "":
        return None
    if isinstance(raw, (list, tuple)):
        return [str(item) for item in raw]
    return [str(raw)]


def _to_tour_summary(tour: Dict[str, Any]) -> TourSummary:
    """Map one row produced by search_tours_tool onto the TourSummary schema."""
    title = tour.get("title")
    max_participants = tour.get("max_participants")
    return TourSummary(
        tour_id=tour.get("tour_id"),
        # A NULL title used to be rendered as the literal string "None".
        title="" if title is None else str(title),
        duration=_optional(tour.get("duration"), str),
        departure_location=_optional(tour.get("departure_location"), str),
        destination=_normalize_destination(tour.get("destination")),
        region=_optional(tour.get("region"), str),
        itinerary=_optional(tour.get("itinerary"), str),
        max_participants=_optional(max_participants, int),
        departure_id=tour.get("departure_id"),
        start_date=tour.get("start_date"),
        price_adult=_optional(tour.get("price_adult"), float),
        price_child_120_140=_optional(tour.get("price_child_120_140"), float),
        price_child_100_120=_optional(tour.get("price_child_100_120"), float),
        promotion_name=_optional(tour.get("promotion_name"), str),
        promotion_type=_optional(tour.get("promotion_type"), str),
        promotion_discount=tour.get("promotion_discount"),
    )


@app.post("/api/tours/search", response_model=PaginatedTourResponse, tags=["Tours"], summary="Search for tours based on user query")
async def search_tours_api(request: TourSearchRequest, current_user_id: int = Depends(get_current_user)):
    """Search tours matching a free-text query and return one page of results.

    The query is first turned into search entities, relative to today's date
    in Asia/Bangkok. Those entities drive the tour search, and the matches are
    paginated in memory. An empty page is returned when extraction fails or
    nothing matches.

    Raises:
        HTTPException: 503 if the search tools failed to import.
    """
    if extract_entities_tool is None or search_tours_tool is None:
        raise HTTPException(status_code=503, detail="Tour search tools not initialized. Check src.tools.")

    # "Today" uses the same Asia/Bangkok zone as the rest of the service, so
    # relative dates in the query don't shift when the server runs in UTC.
    current_date_str = datetime.now(zoneinfo.ZoneInfo("Asia/Bangkok")).strftime('%Y-%m-%d')

    # Both tools are synchronous (LLM extraction and DB search). Run them in
    # worker threads so they don't block the event loop.
    entities = await asyncio.to_thread(
        extract_entities_tool, user_query=request.query, current_date_str=current_date_str
    )
    if not entities or (isinstance(entities, dict) and entities.get("error")):
        return _empty_tour_page(request.page, request.limit)

    all_found_tours = await asyncio.to_thread(search_tours_tool, entities)
    if not all_found_tours:
        return _empty_tour_page(request.page, request.limit)

    total_items = len(all_found_tours)
    total_pages = (total_items + request.limit - 1) // request.limit  # ceiling division

    start_index = (request.page - 1) * request.limit
    page_rows = all_found_tours[start_index:start_index + request.limit]

    return PaginatedTourResponse(
        currentPage=request.page,
        itemsPerPage=request.limit,
        totalItems=total_items,
        totalPages=total_pages,
        hasNextPage=(request.page < total_pages),
        hasPrevPage=(request.page > 1),
        tours=[_to_tour_summary(row) for row in page_rows],
    )
|
|
@app.get("/api/recommendations", response_model=TourRecommendationResponse, tags=["Recommendations"], summary="Get tour recommendations")
async def get_recommendations(
    user_id: Optional[int] = None,
    tour_id: Optional[int] = None,
    limit: int = 3
):
    """Return up to ``limit`` recommended tours.

    user_id and tour_id are passed straight through to get_tour_recommendations.
    A limit outside 1..10 silently falls back to 3. Any failure is reported as
    a 500.
    """
    try:
        effective_limit = limit if 1 <= limit <= 10 else 3
        return get_tour_recommendations(user_id, tour_id, effective_limit)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting recommendations: {str(e)}")
|
|
@app.post("/api/embed", response_model=EmbeddingResponse, tags=["Embeddings"], summary="Generate text embeddings")
async def get_embedding(request: EmbeddingRequest):
    """Embed one string or a batch of strings with the shared embedding model.

    Returns the vectors, the model name, and the length of the first vector.

    Raises:
        HTTPException: 503 if the embedding model is unavailable, 400 for an
            empty batch, 500 if embedding fails.
    """
    if embedding_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Embedding service not initialized. Check src.embedding module."
        )

    # An empty batch has no first vector to measure (embeddings[0] below), so
    # it used to come back as a 500 "list index out of range". It is a client
    # error; reject it up front.
    if isinstance(request.text, list) and not request.text:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="'text' must contain at least one string."
        )

    try:
        embeddings = embedding_model.get_embedding(request.text)
        return {
            "embeddings": embeddings,
            "model": embedding_model.model_name,
            "dimensions": len(embeddings[0]),
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating embeddings: {str(e)}"
        ) from e