"""FastAPI entry point for the RAG demo application.

At import time this module wires together the embedding model, the Milvus
vector database, the generative model (Gemini by default, OpenAI available
as a commented-out alternative), and the Galileo observability platform,
then exposes a small HTML front end plus a ``/search`` endpoint that runs
the full RAG pipeline.
"""

from pathlib import Path

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Form
from fastapi.requests import Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from backend.classes.embedding_model import EmbeddingModel, EmbeddingModelConfig
from backend.classes.galileo_platform import GalileoPlatform, GalileoPlatformConfig
from backend.classes.generative_model import (
    GeminiModel,
    GeminiModelConfig,
    OpenAIModel,
    OpenAIModelConfig,
)
from backend.classes.rag_application import RAGApplication, RAGApplicationConfig
from backend.classes.vector_database.milvus_vector_database import (
    MilvusVectorDatabase,
    MilvusVectorDatabaseConfig,
)
from backend.utils.utils import (
    create_vector_database,
    get_embedding_model,
    get_generative_model,
    initialize_logger,
    read_config,
    set_env_variables,
)

app = FastAPI()
app.mount("/static", StaticFiles(directory="backend/api/static"), name="static")
templates = Jinja2Templates(directory="backend/api/templates")

load_dotenv()
logger = initialize_logger()

# Resolve conf/config.yaml relative to this file so the app starts correctly
# regardless of the current working directory.
config = read_config(str(Path(Path(__file__).parent.parent, "conf/config.yaml")))

# Validate that the required environment variables are set, then select the
# per-environment section of the config (keyed by APP_ENV).
env_variables = set_env_variables(config["env_variables"])
app_config = config[env_variables["APP_ENV"]]
app_config["env_vars"] = env_variables

# Embedding model used to vectorize documents and queries.
embedding_model_config = EmbeddingModelConfig(
    model_name=app_config["embedding_model"]["model_name"],
    batch_size=app_config["embedding_model"]["batch_size"],
)
embedding_model = get_embedding_model(EmbeddingModel, embedding_model_config)

# Milvus vector database; the on-disk file name is derived from MILVUS_DB.
vector_db_config = MilvusVectorDatabaseConfig(
    db_path=app_config["vector_database"]["db_path"]
    + env_variables["MILVUS_DB"]
    + "_milvus.db",
    collection_name=env_variables["MILVUS_DB"],
    vector_dimensions=app_config["vector_database"]["dimensions"],
    drop_if_exists=False,
)
vector_db = create_vector_database(MilvusVectorDatabase, vector_db_config)

# Generative model (Gemini) with a backup API key for failover.
gemini_generative_model_config = GeminiModelConfig(
    model_name=env_variables["GOOGLE_GEMINI_MODEL"],
    api_keys=[
        env_variables["GOOGLE_GEMINI_API_KEY"],
        env_variables["GOOGLE_GEMINI_BACKUP_API_KEY"],
    ],
    temperature=float(env_variables["MODEL_TEMPERATURE"]),
)
gemini_generative_model = get_generative_model(
    GeminiModel, gemini_generative_model_config
)

# Alternative OpenAI backend — uncomment here and swap generative_model in the
# RAGApplicationConfig below to use it instead of Gemini.
# openai_generative_model_config = OpenAIModelConfig(
#     model_name=env_variables["OPENAI_MODEL"],
#     api_key=env_variables["OPENAI_API_KEY"],
#     temperature=float(env_variables["MODEL_TEMPERATURE"]),
# )
# openai_generative_model = get_generative_model(OpenAIModel, openai_generative_model_config)

# Defaults surfaced to the HTML front end (and reused for the Galileo config).
default_project_name = env_variables["GALILEO_PROJECT_NAME"]
default_logstream_name = env_variables["GALILEO_LOGSTREAM_NAME"]
default_protect_stage_name = env_variables["GALILEO_PROTECT_STAGE_NAME"]
default_dataset_name = env_variables["GALILEO_DATASET_NAME"]

# Galileo observability / protection platform.
galileo_platform_config = GalileoPlatformConfig(
    protect_project_name=default_project_name,
    protect_stage_name=default_protect_stage_name,
)
galileo_platform = GalileoPlatform(galileo_platform_config)

# The RAG application ties all components together.
rag_application_config = RAGApplicationConfig(
    embedding_model=embedding_model,
    vector_db=vector_db,
    generative_model=gemini_generative_model,
    # generative_model=openai_generative_model,
    galileo_platform=galileo_platform,
)
rag_app = RAGApplication(rag_application_config)


@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the front-end page pre-filled with default Galileo names."""
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "default_project_name": default_project_name,
            "default_logstream_name": default_logstream_name,
            "default_dataset_name": default_dataset_name,
        },
    )


@app.post("/search")
async def search(
    query: str = Form(...),
    top_k: int = Form(5),
    add_to_dataset: bool = Form(False),
    pii_detection: bool = Form(False),
    hallucination_detection: bool = Form(False),
    induce_hallucination: bool = Form(False),
    project_name: str = Form(...),
    logstream_name: str = Form(...),
    dataset_name: str = Form(...),
) -> JSONResponse:
    """Run the RAG pipeline for *query* and return the answer plus metrics.

    The optional flags toggle PII redaction, hallucination detection, and an
    induced-hallucination mode; ``add_to_dataset`` controls whether the query
    is logged to the named Galileo dataset.
    """
    logger.info("=" * 80)
    logger.info("SEARCH REQUEST RECEIVED")
    # Lazy %-style args avoid formatting work when the log level is disabled.
    logger.info("Query: %s", query)
    logger.info("Top K: %s", top_k)
    logger.info("Add to Dataset: %s", add_to_dataset)
    logger.info("PII Detection: %s", pii_detection)
    logger.info("Hallucination Detection: %s", hallucination_detection)
    logger.info("Induce Hallucination: %s", induce_hallucination)
    logger.info("=" * 80)

    (
        response,
        redacted_response,
        original_response,
        context_adherence_score,
        pii_flag,
    ) = rag_app.run(
        query,
        pii_detection=pii_detection,
        top_k=top_k,
        hallucination_detection=hallucination_detection,
        induce_hallucination=induce_hallucination,
        project_name=project_name,
        logstream_name=logstream_name,
        # Only log to the dataset when the client explicitly opted in.
        dataset_name=dataset_name if add_to_dataset else None,
    )

    return JSONResponse(
        {
            "message": response,
            "redacted_message": redacted_response,
            "original_message": original_response,
            "metrics": {
                "context_adherence": context_adherence_score,
                "pii_flag": pii_flag,
            },
        }
    )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)