Spaces:
Runtime error
Runtime error
File size: 4,893 Bytes
e68d535 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from pathlib import Path
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Form
from fastapi.requests import Request
from fastapi.responses import HTMLResponse
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from backend.classes.embedding_model import EmbeddingModelConfig, EmbeddingModel
from backend.classes.galileo_platform import GalileoPlatformConfig, GalileoPlatform
from backend.classes.generative_model import GeminiModelConfig, GeminiModel, OpenAIModelConfig, OpenAIModel
from backend.classes.rag_application import RAGApplicationConfig, RAGApplication
from backend.classes.vector_database.milvus_vector_database import (
MilvusVectorDatabaseConfig,
MilvusVectorDatabase,
)
from backend.utils.utils import get_embedding_model
from backend.utils.utils import (
initialize_logger,
read_config,
set_env_variables,
create_vector_database,
get_generative_model,
)
# --- Web application and configuration bootstrap ---------------------------
app = FastAPI()

# Static assets and HTML templates are looked up in directories given relative
# to the process working directory.
# NOTE(review): the config file further down is located relative to this file
# instead — confirm the server is always started from the directory that
# contains "backend/".
app.mount(
    "/static",
    StaticFiles(directory="backend/api/static"),
    name="static",
)
templates = Jinja2Templates(directory="backend/api/templates")

# Populate the process environment from a .env file, if one is found.
load_dotenv()
logger = initialize_logger()

# conf/config.yaml is resolved against the parent of the directory that
# contains this file.
_CONFIG_FILE = Path(__file__).parent.parent / "conf/config.yaml"
config = read_config(str(_CONFIG_FILE))

# set_env_variables receives the "env_variables" section of the config; the
# mapping it returns must contain "APP_ENV", whose value selects which
# top-level section of the config is used as the application settings.
env_variables = set_env_variables(config["env_variables"])
app_config = config[env_variables["APP_ENV"]]
# Keep the environment values alongside the selected settings.
app_config["env_vars"] = env_variables
# Embedding model: built from the "embedding_model" section of the selected
# application settings.
_embedding_settings = app_config["embedding_model"]
embedding_model_config = EmbeddingModelConfig(
    model_name=_embedding_settings["model_name"],
    batch_size=_embedding_settings["batch_size"],
)
embedding_model = get_embedding_model(EmbeddingModel, embedding_model_config)
# Vector database: a Milvus-backed store configured from the
# "vector_database" section of the selected application settings.
_vector_db_settings = app_config["vector_database"]
vector_db_config = MilvusVectorDatabaseConfig(
    db_path=_vector_db_settings["db_path"],
    collection_name=_vector_db_settings["collection_name"],
    vector_dimensions=_vector_db_settings["dimensions"],
    # NOTE(review): presumably keeps an already-existing collection instead of
    # recreating it on startup — confirm against MilvusVectorDatabase.
    drop_if_exists=False,
)
vector_db = create_vector_database(MilvusVectorDatabase, vector_db_config)
# Generative model (Gemini): configured with two API keys taken from the
# environment values (the second one is named as a backup key).
# NOTE(review): this object is built at import time but is not passed to the
# RAG application configuration in this file — confirm it is still needed.
_gemini_settings = app_config["gemini_generative_model"]
_gemini_api_keys = [
    env_variables["GOOGLE_GEMINI_API_KEY"],
    env_variables["GOOGLE_GEMINI_BACKUP_API_KEY"],
]
gemini_generative_model_config = GeminiModelConfig(
    model_name=_gemini_settings["model_name"],
    api_keys=_gemini_api_keys,
    temperature=_gemini_settings["temperature"],
)
gemini_generative_model = get_generative_model(
    GeminiModel,
    gemini_generative_model_config,
)
# Generative model (OpenAI): configured from the "openai_generative_model"
# settings section plus the OPENAI_API_KEY environment value.
_openai_settings = app_config["openai_generative_model"]
openai_generative_model_config = OpenAIModelConfig(
    model_name=_openai_settings["model_name"],
    api_key=env_variables["OPENAI_API_KEY"],
    temperature=_openai_settings["temperature"],
)
openai_generative_model = get_generative_model(
    OpenAIModel,
    openai_generative_model_config,
)
# Galileo platform: every config field is copied verbatim from the
# identically named key of the "galileo_platform" settings section.
_galileo_settings = app_config["galileo_platform"]
galileo_platform_config = GalileoPlatformConfig(
    **{
        field_name: _galileo_settings[field_name]
        for field_name in (
            "evaluate_project_name",
            "observe_project_name",
            "protect_project_name",
            "protect_stage_name",
        )
    }
)
galileo_platform = GalileoPlatform(galileo_platform_config)
# RAG application: wires together the embedding model, the vector database,
# a generative model and the Galileo platform object created above.
# The OpenAI model is the one supplied as the generator;
# gemini_generative_model is not passed here.
_rag_components = {
    "embedding_model": embedding_model,
    "vector_db": vector_db,
    "generative_model": openai_generative_model,
    "galileo_platform": galileo_platform,
}
rag_application_config = RAGApplicationConfig(**_rag_components)
rag_app = RAGApplication(rag_application_config)
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the landing page by rendering the ``index.html`` template."""
    # The template context carries the incoming request under "request".
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
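# TODO(Nikhil): the unfinished stub below suggests a planned POST
# "/other-metrics" endpoint — implement it or remove the stub.
# @app.post("/other-metrics")
# async def search(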
@app.post("/search")
async def search(
    query: str = Form(...),
    top_k: int = Form(5),
    protection: bool = Form(False),
    hallucination_detection: bool = Form(False),
    induce_hallucination: bool = Form(False),
):
    """Run the RAG application for a submitted form and return its result as JSON.

    Args:
        query: Required form field; passed to ``rag_app.run`` as its first
            positional argument.
        top_k: Form field (default 5); forwarded as ``top_k``.
        protection: Form field (default False); forwarded as
            ``protect_enabled``.
        hallucination_detection: Form field (default False); forwarded under
            the same name.
        induce_hallucination: Form field (default False); forwarded under the
            same name.

    Returns:
        A ``JSONResponse`` whose body has the keys ``"message"``,
        ``"redacted_message"``, ``"original_message"`` and ``"metrics"``; the
        latter holds ``"context_adherence"`` and ``"pii_flag"``.
    """
    run_options = {
        "protect_enabled": protection,
        "top_k": top_k,
        "hallucination_detection": hallucination_detection,
        "induce_hallucination": induce_hallucination,
    }
    # NOTE(review): rag_app.run is called synchronously inside an async
    # handler; if it performs blocking work, other requests wait until it
    # returns — confirm this is acceptable for the deployment.
    outcome = rag_app.run(query, **run_options)

    # run() must yield exactly five values, unpacked in this order.
    answer, redacted_answer, original_answer, adherence_score, has_pii = outcome

    metrics = {
        "context_adherence": adherence_score,
        "pii_flag": has_pii,
    }
    payload = {
        "message": answer,
        "redacted_message": redacted_answer,
        "original_message": original_answer,
        "metrics": metrics,
    }
    return JSONResponse(payload)
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on every network
    # interface ("0.0.0.0"), port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|