from contextlib import asynccontextmanager
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel

from Embedder.E5_Embeddedr import E5_Embeddedr
from Models.Utils import *
from Models.Prompts import *
from OLAP_Conn.DuckConn import DuckConn
from RAG.RAG_Retrival import RAG_Retrival
######################################################
####----------------PARAMETER CLASSES-----------------
######################################################
class Message(BaseModel):
    role: str
    content: str


class ModelParameter(BaseModel):
    model: str
    max_token: int
    temperature: float
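# With two Pydantic body parameters, FastAPI expects the request body to embed
# them under their parameter names. An illustrative payload (values are
# placeholders, not from the original):
#   {"messages":   [{"role": "user", "content": "..."}],
#    "parameters": {"model": "...", "max_token": 512, "temperature": 0.2}}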
######################################################
####-------------------DEFINITIONS--------------------
######################################################
PATH_DUCK = "Data.duckdb"

# Shared resources, initialised once at startup by the lifespan handler below.
# (The model and retriever are built per request inside the route instead.)
db = None
embedder = None
######################################################
####-----------------STARTUP EVENTS-------------------
######################################################
@asynccontextmanager  # required for a FastAPI lifespan handler; missing in the original
async def lifespan(app: FastAPI):
    # Open the DuckDB connection and load the embedder once, before serving.
    global db, embedder
    db = DuckConn(PATH_DUCK)
    embedder = E5_Embeddedr()
    yield
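    # Teardown after the yield is a sketch, assuming DuckConn exposes a
    # close() method (not shown in the original):
    # db.close()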

app = FastAPI(lifespan=lifespan)
####################################################
####--------------------ROUTES----------------------
####################################################
@app.post("/chat")  # decorator absent in the original listing; path is an assumption
async def chat(messages: List[Message], parameters: ModelParameter):
    model = get_specific_model(parameters.model)
    model.set_config(temperature=parameters.temperature, max_tokens=parameters.max_token)
    rag_retriv = RAG_Retrival(db, model, embedder)
    # Convert Pydantic objects to dicts
    messages_data = [msg.model_dump() for msg in messages]
    prompt = messages_data[0]['content']
    # Retrieve queries relevant to the prompt and fold them into the final prompt.
    relevant_queries = rag_retriv.query_relevant(prompt)
    relevant_queries = ''.join(relevant_queries)
    final_query = [message_user(final_prompt(prompt, relevant_queries))]
    model_answer = model.send_message(final_query)
    return {"status": "success", "response": model_answer}