# app/rag/routes.py
"""
RAG FastAPI routes.

This module exposes two endpoints (under the ``/rag`` prefix):

- ``POST /initialization/{onboarding_id}/{doc_type}``: ingest documents into a
  Qdrant collection and create (or reuse) a chat session.
- ``POST /chat/{onboarding_id}/{doc_type}/{chat_id}``: answer a question with a
  retrieval-augmented chain built over the stored vector store.

Both handlers log extensively to ease debugging and to surface metrics:

- request start/finish and duration
- counts (documents, chunks, vectors, batches)
- Qdrant operations and retries
- embedding-method selection failures

NOTE(review): the handlers are ``async def`` but call blocking code (the
embedding model, the synchronous Qdrant client, ``time.sleep`` between upsert
retries, and presumably the DB/chat-history helpers). That blocks the event
loop for the duration of each request; consider plain ``def`` handlers (which
FastAPI runs in a thread pool) or offloading the heavy calls.
"""

import os
import json
import uuid
import time
from collections.abc import Mapping
from typing import Any, List, Optional, Iterable, Sequence

from fastapi import APIRouter, HTTPException, Path, Query
from pydantic import BaseModel

from .schemas import SetupRequest, ChatRequest, SetupResponse, ChatResponse
from .utils import (
    get_vectorstore_path,
    save_vectorstore_to_disk,
    upsert_vectorstore_metadata,
    get_vectorstore_metadata,
    build_rag_chain,
)
from .chat_history import ChatHistoryManager
from .logging_config import logger

from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, PointStruct, Distance

from app.page_speed.config import settings
from .embeddings import embeddings, text_splitter  # kept here for ingestion

router = APIRouter(prefix="/rag", tags=["rag"])

# Bulk embedding methods tried, in order, on the ``embeddings`` object.
_BULK_EMBED_METHODS = ("embed_documents", "embed_texts", "embed_batch", "embed")


def _normalize_vector(vec: Any) -> Any:
    """
    Return a single embedding in a uniform shape.

    Unwraps ``{"embedding": [...]}`` items and converts array-like objects that
    expose ``tolist()`` (e.g. numpy rows) into plain lists, so downstream code
    (truthiness/length checks, ``PointStruct``) receives ordinary lists.
    Anything else is returned unchanged.
    """
    if isinstance(vec, dict) and "embedding" in vec:
        vec = vec["embedding"]
    tolist = getattr(vec, "tolist", None)
    if callable(tolist):
        vec = tolist()
    return vec


def _get_embeddings_for_texts(texts: List[str]) -> List[List[float]]:
    """
    Compute one embedding vector per text.

    Tries each method in ``_BULK_EMBED_METHODS`` on the module-level
    ``embeddings`` object, in order, and returns the first result that does not
    raise and is not ``None``. If every bulk method is missing or fails, falls
    back to calling ``embed_query`` (or ``embed``) once per text.

    Args:
        texts: Texts to embed.

    Returns:
        Vectors in the same order as ``texts`` (normalised via
        ``_normalize_vector``). Empty input yields an empty list.

    Raises:
        RuntimeError: if ``embeddings`` exposes no supported method.
        Exception: whatever the single-item embedding call raises.
    """
    if not texts:
        logger.debug("_get_embeddings_for_texts called with empty texts list.")
        return []

    logger.debug("Computing embeddings for %d texts", len(texts))

    # Preferred bulk API methods.
    for attr in _BULK_EMBED_METHODS:
        fn = getattr(embeddings, attr, None)
        if not callable(fn):
            continue
        logger.debug("Trying embedding method: %s", attr)
        try:
            raw = fn(texts)
            if raw is None:
                # A None result is unusable; treat it like a failure and move on.
                raise ValueError(f"{attr} returned None")
            vecs = [_normalize_vector(v) for v in raw]
        except Exception:
            logger.debug("Embedding method %s failed; trying next option", attr, exc_info=True)
            continue
        logger.debug("Embedding method %s returned %d vectors", attr, len(vecs))
        return vecs

    # Fallback: call a single-item embedding function once per text.
    single_fn = getattr(embeddings, "embed_query", None) or getattr(embeddings, "embed", None)
    if not callable(single_fn):
        logger.error("Embeddings object does not expose a supported embedding method")
        raise RuntimeError(
            "Embeddings object does not expose a supported embedding method "
            "(embed_documents/embed_texts/embed_batch/embed_query/embed)."
        )

    logger.debug("Falling back to single-item embedding function: %s", getattr(single_fn, "__name__", ""))
    vecs = []
    for i, text in enumerate(texts):
        try:
            vecs.append(_normalize_vector(single_fn(text)))
        except Exception as e:
            logger.exception("Single-item embedding failed for text index %d: %s", i, e)
            raise
    logger.debug("Single-item embedding produced %d vectors", len(vecs))
    return vecs


def _build_qdrant_client() -> QdrantClient:
    """
    Construct a ``QdrantClient`` from ``settings``.

    ``qdrant_url`` and ``qdrant_api_key`` are passed only when truthy;
    ``qdrant_timeout`` defaults to 60 (seconds) and ``qdrant_prefer_grpc`` to
    False when absent from ``settings``.

    Raises:
        HTTPException: 500 if the client cannot be constructed.
    """
    client_kwargs = {}
    if getattr(settings, "qdrant_url", None):
        client_kwargs["url"] = settings.qdrant_url
    if getattr(settings, "qdrant_api_key", None):
        client_kwargs["api_key"] = settings.qdrant_api_key
    qdrant_timeout = getattr(settings, "qdrant_timeout", 60)  # seconds
    prefer_grpc = getattr(settings, "qdrant_prefer_grpc", False)

    try:
        client = QdrantClient(**client_kwargs, timeout=qdrant_timeout, prefer_grpc=prefer_grpc)
    except (TypeError, ValueError) as e:
        logger.exception("Failed to instantiate QdrantClient: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to construct Qdrant client: {e}") from e

    if client_kwargs:
        logger.debug("Instantiated QdrantClient with kwargs keys: %s", list(client_kwargs.keys()))
    else:
        logger.debug("Instantiated QdrantClient with default connection (no url/api_key)")
    return client


def _safe_upsert(
    client: QdrantClient,
    collection_name: str,
    points: Iterable[PointStruct],
    max_retries: int = 3,
    initial_backoff: float = 1.0,
) -> None:
    """
    Upsert ``points`` into ``collection_name``, retrying with exponential backoff.

    ``points`` is materialised into a list up front so every attempt sends the
    same points (a generator would be exhausted after the first attempt).

    Args:
        client: Qdrant client to use.
        collection_name: Target collection.
        points: Points to upsert.
        max_retries: Total number of attempts; values below 1 are treated as 1
            so the upsert is never silently skipped.
        initial_backoff: Seconds slept after the first failure; doubled after
            each further failure.

    Raises:
        Exception: the last upsert error once all attempts have failed.
    """
    batch = list(points)
    attempts = max(1, max_retries)
    backoff = initial_backoff
    for attempt in range(1, attempts + 1):
        try:
            client.upsert(collection_name=collection_name, points=batch)
        except Exception as exc:
            logger.warning("Qdrant upsert attempt %d/%d failed: %s", attempt, attempts, exc)
            if attempt >= attempts:
                logger.exception("Qdrant upsert failed after %d attempts", attempts)
                raise
            # Blocking sleep: this runs inside an async handler (see module NOTE).
            time.sleep(backoff)
            backoff *= 2.0
        else:
            logger.debug(
                "Safe upsert successful for %d points (collection=%s) on attempt %d",
                len(batch), collection_name, attempt,
            )
            return


def _upsert_chunks(
    client: QdrantClient,
    collection_name: str,
    vectors: Sequence[Any],
    text_chunks: Sequence[str],
) -> int:
    """
    Upsert ``(vector, text)`` pairs as Qdrant points, in batches.

    Each point gets a random UUID id and a ``{"text": chunk}`` payload. The
    batch size comes from ``settings.qdrant_upsert_batch_size`` (default 64,
    clamped to at least 1).

    Returns:
        Number of points upserted.
    """
    batch_size = max(1, int(getattr(settings, "qdrant_upsert_batch_size", 64)))
    pairs = list(zip(vectors, text_chunks))
    for start in range(0, len(pairs), batch_size):
        batch = [
            PointStruct(id=str(uuid.uuid4()), vector=vec, payload={"text": txt})
            for vec, txt in pairs[start:start + batch_size]
        ]
        processed = start + len(batch)
        if len(batch) == batch_size:
            logger.debug("Upserting batch of %d points to collection %s (processed=%d)",
                         len(batch), collection_name, processed)
        else:
            logger.debug("Upserting final batch of %d points to collection %s (processed=%d)",
                         len(batch), collection_name, processed)
        _safe_upsert(client, collection_name, batch)
    return len(pairs)


def _extract_answer(result: Any) -> str:
    """
    Pull the answer text out of a RAG chain result.

    Mapping results use ``"answer"`` then ``"output_text"`` (first truthy value
    wins). Non-mapping results are used directly; message-like objects
    contribute their ``.content``. ``None``/missing values become ``""`` and
    other non-string values are converted with ``str``.
    """
    if isinstance(result, Mapping):
        answer = result.get("answer") or result.get("output_text") or ""
    else:
        answer = result
    answer = getattr(answer, "content", answer)
    if answer is None:
        return ""
    return answer if isinstance(answer, str) else str(answer)


@router.post("/initialization/{onboarding_id}/{doc_type}", response_model=SetupResponse)
async def setup_rag_session(
    onboarding_id: str = Path(..., description="Unique onboarding identifier"),
    doc_type: str = Path(..., description="Type of document (e.g., page_speed, seo, content_relevance, uiux or mobile_usability)"),
    body: SetupRequest = ...,
):
    """
    Ingest documents for one document type and return a chat session id.

    Behavior:
    - If vectorstore metadata already exists in the DB for
      ``(onboarding_id, doc_type)``, ingestion is skipped (idempotent) and
      ``body.documents`` is ignored. The stored ``chat_id`` is reused when
      present (a new one is generated otherwise), its chat session is created
      if missing, and the metadata is re-upserted so the ``chat_id`` is
      recorded.
    - Otherwise the documents are joined, split into chunks, embedded, and
      written to a freshly (re)created Qdrant collection
      ``vs_{onboarding_id}_{doc_type}``. A new chat session is created only
      after ingestion succeeds, and the metadata is persisted via
      ``save_vectorstore_to_disk`` and ``upsert_vectorstore_metadata``.

    Returns:
        SetupResponse

    Raises:
        HTTPException: 400 when no documents (or no text) are provided; 500 on
        embedding, Qdrant, persistence, or unexpected failures.
    """
    start_ts = time.perf_counter()
    logger.info("RAG initialization called for onboarding_id=%s doc_type=%s", onboarding_id, doc_type)

    try:
        # DB metadata (not a filesystem marker) decides whether to ingest.
        existing_meta = get_vectorstore_metadata(onboarding_id, doc_type)
        if existing_meta:
            logger.info(
                "Vectorstore metadata exists for onboarding_id=%s, doc_type=%s; skipping ingestion",
                onboarding_id, doc_type,
            )
            chat_id = existing_meta.get("chat_id") or str(uuid.uuid4())
            if not ChatHistoryManager.chat_exists(chat_id):
                ChatHistoryManager.create_session(chat_id)
                logger.debug("Created new chat session for existing metadata chat_id=%s", chat_id)

            # Ensure the DB has a chat_id (metadata may predate it or lack fields).
            upsert_vectorstore_metadata(
                onboarding_id,
                doc_type,
                existing_meta.get("vectorstore_path"),
                chat_id,
                existing_meta.get("collection_name"),
            )

            duration = time.perf_counter() - start_ts
            logger.info("RAG initialization skipped ingestion (existing); duration=%.3fs", duration)
            return SetupResponse(
                success=True,
                message="RAG setup completed with existing vectorstore metadata.",
                onboarding_id=onboarding_id,
                doc_type=doc_type,
                chat_id=chat_id,
                vectorstore_path=existing_meta.get("vectorstore_path"),
            )

        # ---------------------
        # New ingestion flow
        # ---------------------
        if not body.documents:
            logger.error("Missing documents for onboarding_id=%s, doc_type=%s", onboarding_id, doc_type)
            raise HTTPException(status_code=400, detail="Please provide documents to ingest.")

        logger.info("Ingesting %d documents for %s/%s", len(body.documents), onboarding_id, doc_type)
        chat_id = str(uuid.uuid4())

        all_text = "\n\n".join(body.documents)
        text_chunks = text_splitter.split_text(all_text)
        logger.info("Split documents into %d text chunks", len(text_chunks))
        if not text_chunks:
            # e.g. all documents empty/whitespace: a client error, not an embedding failure.
            logger.error("No text chunks produced for onboarding_id=%s, doc_type=%s", onboarding_id, doc_type)
            raise HTTPException(status_code=400, detail="Provided documents contain no text to ingest.")

        qdrant_client = _build_qdrant_client()

        # Deterministic collection name per onboarding/doc_type.
        collection_name = f"vs_{onboarding_id}_{doc_type}"
        logger.info("Using Qdrant collection name: %s", collection_name)

        # Compute embeddings.
        try:
            vectors = _get_embeddings_for_texts(text_chunks)
        except Exception as e:
            logger.exception("Failed to compute embeddings: %s", e)
            raise HTTPException(status_code=500, detail=f"Embedding error: {e}") from e

        if not vectors or len(vectors) != len(text_chunks):
            logger.error(
                "Embeddings length mismatch: vectors=%s texts=%s",
                len(vectors) if vectors is not None else None,
                len(text_chunks),
            )
            raise HTTPException(status_code=500, detail="Embedding generation failed or returned unexpected shape.")

        vector_size = len(vectors[0])
        logger.info("Computed embeddings: count=%d vector_size=%d", len(vectors), vector_size)
        if vector_size == 0:
            logger.error("Embedding returned empty vectors (vector_size=0)")
            raise HTTPException(status_code=500, detail="Embedding returned empty vectors")

        # Drop and recreate the collection, discarding any previous points for
        # this onboarding/doc_type. NOTE(review): newer qdrant-client releases
        # deprecate recreate_collection in favour of collection_exists /
        # delete_collection / create_collection — verify against the pinned version.
        try:
            qdrant_client.recreate_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )
            logger.info("Recreated Qdrant collection %s (vector_size=%d)", collection_name, vector_size)
        except Exception as e:
            logger.exception("Failed to create/recreate qdrant collection '%s': %s", collection_name, e)
            raise HTTPException(status_code=500, detail=f"Failed to create qdrant collection: {e}") from e

        try:
            total_points = _upsert_chunks(qdrant_client, collection_name, vectors, text_chunks)
        except Exception as e:
            logger.exception("Failed to upsert points into qdrant: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to upsert points into Qdrant: {e}") from e
        logger.info("Upserted total %d points into Qdrant collection %s", total_points, collection_name)

        # Create the chat session only after ingestion succeeded, so failed
        # ingestions do not leave orphan sessions behind.
        ChatHistoryManager.create_session(chat_id)
        logger.debug("Created chat session %s", chat_id)

        # Build the in-application "vectorstore_path" (URI-style).
        try:
            vs_path = save_vectorstore_to_disk(
                onboarding_id,
                doc_type,
                collection_name,
                getattr(settings, "qdrant_url", None),
                getattr(settings, "qdrant_api_key", None),
            )
            logger.debug("Saved vectorstore metadata path: %s", vs_path)
        except Exception as e:
            logger.exception("Failed to save vectorstore metadata to disk/DB: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to persist vectorstore metadata: {e}") from e

        # Persist metadata into the DB.
        try:
            upsert_vectorstore_metadata(onboarding_id, doc_type, vs_path, chat_id, collection_name)
            logger.info("Persisted vectorstore metadata for %s/%s (chat_id=%s)", onboarding_id, doc_type, chat_id)
        except Exception as e:
            logger.exception("Failed to upsert vectorstore metadata into DB: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to persist vectorstore metadata: {e}") from e

        duration = time.perf_counter() - start_ts
        logger.info(
            "Created Qdrant collection %s for %s/%s (points=%d) in %.3fs",
            collection_name, onboarding_id, doc_type, total_points, duration,
        )
        return SetupResponse(
            success=True,
            message="RAG setup completed.",
            onboarding_id=onboarding_id,
            doc_type=doc_type,
            chat_id=chat_id,
            vectorstore_path=vs_path,
        )

    except HTTPException:
        # Already mapped to a client-facing status; propagate unchanged.
        raise
    except Exception as exc:
        logger.exception("Unhandled exception during RAG initialization for %s/%s: %s", onboarding_id, doc_type, exc)
        raise HTTPException(status_code=500, detail=f"Internal server error during RAG initialization: {exc}") from exc


@router.post("/chat/{onboarding_id}/{doc_type}/{chat_id}", response_model=ChatResponse)
async def chat_with_user(
    onboarding_id: str = Path(...),
    doc_type: str = Path(...),
    chat_id: str = Path(...),
    prompt_type: str = Query(..., description="Prompt type, e.g., page_speed, content_relevance, seo, uiux or mobile_usability"),
    body: ChatRequest = ...,
):
    """
    Answer a question using the vector store for one document type.

    Steps:
    - Verify vectorstore metadata exists (400 otherwise).
    - Verify the chat session exists (404 otherwise).
    - Reject an empty/whitespace question (400).
    - Let ``ChatHistoryManager.summarize_if_needed`` compact the history
      (threshold=10), then record the human turn.
    - Build the RAG chain and invoke it with the question and chat history.
    - Record the AI turn and return it.

    Raises:
        HTTPException: 400/404 for the validation failures above; 500 if the
        chain invocation or anything else fails.
    """
    start_ts = time.perf_counter()
    logger.info("Chat request received: onboarding_id=%s doc_type=%s chat_id=%s prompt_type=%s",
                onboarding_id, doc_type, chat_id, prompt_type)

    try:
        # DB metadata (not a filesystem marker) proves initialization ran.
        metadata = get_vectorstore_metadata(onboarding_id, doc_type)
        if not metadata:
            logger.warning("Vectorstore metadata not found for %s/%s", onboarding_id, doc_type)
            raise HTTPException(status_code=400, detail="Vectorstore metadata not found; run initialization first.")

        if not ChatHistoryManager.chat_exists(chat_id):
            logger.warning("Chat session %s not found", chat_id)
            raise HTTPException(status_code=404, detail=f"Chat session {chat_id} not found.")

        question = (body.question or "").strip()
        if not question:
            logger.warning("Empty question in chat request for chat_id=%s", chat_id)
            raise HTTPException(status_code=400, detail="Question cannot be empty.")

        logger.info("Processing question (len=%d) for chat_id=%s", len(question), chat_id)

        ChatHistoryManager.summarize_if_needed(chat_id, threshold=10)
        ChatHistoryManager.add_message(chat_id, role="human", content=question)
        logger.debug("Added human message to history for chat_id=%s", chat_id)

        chain = build_rag_chain(onboarding_id, doc_type, chat_id, prompt_type)
        logger.debug("Built RAG chain for onboarding_id=%s doc_type=%s chat_id=%s", onboarding_id, doc_type, chat_id)

        # NOTE(review): history is read after the current question was appended,
        # so chat_history presumably already contains it; also, if invocation
        # fails the human turn stays recorded without an AI reply. Confirm both
        # are intended by build_rag_chain's prompt.
        history = ChatHistoryManager.get_messages(chat_id)
        logger.debug("Chat history length=%d for chat_id=%s", len(history), chat_id)

        try:
            result = chain.invoke({"question": question, "chat_history": history})
            logger.debug("RAG chain invoked successfully for chat_id=%s", chat_id)
        except Exception as e:
            logger.exception("RAG chain invocation failed for chat_id=%s: %s", chat_id, e)
            raise HTTPException(status_code=500, detail=f"RAG chain invocation failed: {e}") from e

        answer = _extract_answer(result)
        logger.info("Generated answer length=%d for chat_id=%s", len(answer), chat_id)

        ChatHistoryManager.add_message(chat_id, role="ai", content=answer)

        duration = time.perf_counter() - start_ts
        logger.info("Chat request completed for chat_id=%s duration=%.3fs", chat_id, duration)

        return ChatResponse(
            success=True,
            answer=answer,
            error=None,
            chat_id=chat_id,
            onboarding_id=onboarding_id,
            doc_type=doc_type,
        )

    except HTTPException:
        # Already mapped to a client-facing status; propagate unchanged.
        raise
    except Exception as exc:
        logger.exception("Unhandled exception during chat for %s/%s chat_id=%s: %s", onboarding_id, doc_type, chat_id, exc)
        raise HTTPException(status_code=500, detail=f"Internal server error during chat: {exc}") from exc