# NOTE(review): removed Hugging Face Spaces page chrome (status banner, commit
# hashes, file-size line, and line-number gutter) that had been scraped into
# this file — it was not part of the Python source and broke parsing.
from pathlib import Path
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Form
from fastapi.requests import Request
from fastapi.responses import HTMLResponse
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from backend.classes.embedding_model import EmbeddingModelConfig, EmbeddingModel
from backend.classes.galileo_platform import GalileoPlatformConfig, GalileoPlatform
from backend.classes.generative_model import GeminiModelConfig, GeminiModel, OpenAIModelConfig, OpenAIModel
from backend.classes.rag_application import RAGApplicationConfig, RAGApplication
from backend.classes.vector_database.milvus_vector_database import (
MilvusVectorDatabaseConfig,
MilvusVectorDatabase,
)
from backend.utils.utils import get_embedding_model
from backend.utils.utils import (
initialize_logger,
read_config,
set_env_variables,
create_vector_database,
get_generative_model,
)
# --- Application bootstrap (runs at import time) ------------------------------
# The module-level statements below build every long-lived object the route
# handlers use: the FastAPI app, Jinja2 templates, parsed config, embedding
# model, Milvus vector DB, Gemini generative model, Galileo platform client,
# and the RAG application itself. Statement order matters: .env must be loaded
# before any environment lookups, and config before env validation.
app = FastAPI()
# Static assets served under /static; HTML pages rendered via Jinja2 templates.
app.mount("/static", StaticFiles(directory="backend/api/static"), name="static")
templates = Jinja2Templates(directory="backend/api/templates")
# Load variables from a local .env file before any env reads below.
load_dotenv()
logger = initialize_logger()
# Resolve conf/config.yaml relative to this file's grandparent directory
# (two levels up from this module), so the path works regardless of CWD.
config = read_config(str(Path(Path(__file__).parent.parent, "conf/config.yaml")))
# Collect/validate the environment variables declared in the config.
env_variables = set_env_variables(config["env_variables"])
# Select the per-environment config section keyed by APP_ENV
# (presumably values like "dev"/"prod" — see conf/config.yaml).
app_config = config[env_variables["APP_ENV"]]
app_config["env_vars"] = env_variables
# Embedding model used to vectorize queries/documents.
embedding_model_config = EmbeddingModelConfig(
    model_name=app_config["embedding_model"]["model_name"],
    batch_size=app_config["embedding_model"]["batch_size"],
)
embedding_model = get_embedding_model(EmbeddingModel, embedding_model_config)
# Milvus vector database; the on-disk DB file name is derived from MILVUS_DB.
# drop_if_exists=False preserves any previously ingested collection.
vector_db_config = MilvusVectorDatabaseConfig(
    db_path=app_config["vector_database"]["db_path"] + env_variables["MILVUS_DB"] + "_milvus.db",
    collection_name=env_variables["MILVUS_DB"],
    vector_dimensions=app_config["vector_database"]["dimensions"],
    drop_if_exists=False,
)
vector_db = create_vector_database(MilvusVectorDatabase, vector_db_config)
# Gemini generative model; two API keys are supplied — the second is presumably
# a fallback on quota/auth failure (confirm against GeminiModel's key handling).
gemini_generative_model_config = GeminiModelConfig(
    model_name=env_variables["GOOGLE_GEMINI_MODEL"],
    api_keys=[env_variables["GOOGLE_GEMINI_API_KEY"], env_variables["GOOGLE_GEMINI_BACKUP_API_KEY"]],
    temperature=float(env_variables["MODEL_TEMPERATURE"]),
)
gemini_generative_model = get_generative_model(GeminiModel, gemini_generative_model_config)
# Alternative OpenAI backend: uncomment this block and switch the
# generative_model argument in rag_application_config further down.
# openai_generative_model_config = OpenAIModelConfig(
#     model_name=env_variables["OPENAI_MODEL"],
#     api_key=env_variables["OPENAI_API_KEY"],
#     temperature=float(env_variables["MODEL_TEMPERATURE"]),
# )
# openai_generative_model = get_generative_model(OpenAIModel, openai_generative_model_config)
# Defaults pre-filled into the UI form by the index route.
default_project_name = env_variables["GALILEO_PROJECT_NAME"]
default_logstream_name = env_variables["GALILEO_LOGSTREAM_NAME"]
default_protect_stage_name = env_variables["GALILEO_PROTECT_STAGE_NAME"]
default_dataset_name = env_variables["GALILEO_DATASET_NAME"]
# Galileo observability/protection platform client.
galileo_platform_config = GalileoPlatformConfig(
    protect_project_name=env_variables["GALILEO_PROJECT_NAME"],
    protect_stage_name=env_variables["GALILEO_PROTECT_STAGE_NAME"],
)
galileo_platform = GalileoPlatform(galileo_platform_config)
# Wire everything into the RAG application used by the /search endpoint.
rag_application_config = RAGApplicationConfig(
    embedding_model=embedding_model,
    vector_db=vector_db,
    generative_model=gemini_generative_model,
    # generative_model=openai_generative_model,
    galileo_platform=galileo_platform,
)
rag_app = RAGApplication(rag_application_config)
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the landing page, pre-filled with the default Galileo names."""
    template_context = {
        "request": request,
        "default_project_name": default_project_name,
        "default_logstream_name": default_logstream_name,
        "default_dataset_name": default_dataset_name,
    }
    return templates.TemplateResponse("index.html", template_context)
@app.post("/search")
async def search(
    query: str = Form(...),
    top_k: int = Form(5),
    add_to_dataset: bool = Form(False),
    pii_detection: bool = Form(False),
    hallucination_detection: bool = Form(False),
    induce_hallucination: bool = Form(False),
    project_name: str = Form(...),
    logstream_name: str = Form(...),
    dataset_name: str = Form(...),
) -> JSONResponse:
    """Run the RAG pipeline for a form-submitted search query.

    Args:
        query: The user's natural-language question.
        top_k: Number of retrieved context chunks to use.
        add_to_dataset: When True, ``dataset_name`` is forwarded to the RAG
            run so the query is recorded in that Galileo dataset.
        pii_detection: Enable PII detection/redaction on the response.
        hallucination_detection: Enable context-adherence scoring.
        induce_hallucination: Deliberately provoke a hallucination (demo toggle).
        project_name: Galileo project to log the run under.
        logstream_name: Galileo log stream to log the run under.
        dataset_name: Galileo dataset name (only used if ``add_to_dataset``).

    Returns:
        JSONResponse with the final message, the redacted and original
        variants, and metrics (context adherence score, PII flag).
    """
    logger.info("=" * 80)
    logger.info("SEARCH REQUEST RECEIVED")
    # Lazy %-style logging args: formatting is skipped when INFO is disabled.
    logger.info("Query: %s", query)
    logger.info("Top K: %s", top_k)
    logger.info("Add to Dataset: %s", add_to_dataset)
    logger.info("PII Detection: %s", pii_detection)
    logger.info("Hallucination Detection: %s", hallucination_detection)
    logger.info("Induce Hallucination: %s", induce_hallucination)
    logger.info("=" * 80)
    # dataset_name is only passed through when the user opted in.
    response, redacted_response, original_response, context_adherence_score, pii_flag = rag_app.run(
        query,
        pii_detection=pii_detection,
        top_k=top_k,
        hallucination_detection=hallucination_detection,
        induce_hallucination=induce_hallucination,
        project_name=project_name,
        logstream_name=logstream_name,
        dataset_name=dataset_name if add_to_dataset else None,
    )
    return JSONResponse(
        {
            "message": response,
            "redacted_message": redacted_response,
            "original_message": original_response,
            "metrics": {
                "context_adherence": context_adherence_score,
                "pii_flag": pii_flag,
            },
        }
    )
def _serve() -> None:
    """Launch the development server, binding all interfaces on port 8000."""
    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    _serve()