meet4150/ALIV_AI / app /main.py
download
raw
6.34 kB
from __future__ import annotations
import os
from pathlib import Path
from uuid import uuid4
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from pydantic import BaseModel
from app.agent.health_agent import HealthAgent
from app.agent.kb_embedding import KBEmbeddingService
from app.agent.kb_retrieval import validate_similarity
from app.celery_app import celery_app
from app.db.chroma_client import get_vector_backend, safe_count
from app.ingestion.pipeline import IngestionOptions, ingest_file, ingest_text
from app.nlp.nlp_service import NLPService
from app.tasks.ingestion_tasks import ingest_file_task, ingest_text_task
# FastAPI application instance for the AliveAI medical RAG chatbot.
app = FastAPI(title="AliveAI Medical RAG Chatbot")
# One HealthAgent per chat session, keyed by session_id.
# In-memory only — agents (and their conversation state) are lost on restart.
SESSION_AGENTS: dict[str, HealthAgent] = {}
# Model identifier handed to each new HealthAgent; overridable via environment.
DEFAULT_MODEL = os.getenv("ALIVEAI_HEALTH_MODEL", "aaditya/Llama3-OpenBioLLM-8B")
# Uploaded documents are persisted under <project root>/data/uploads before ingestion.
UPLOAD_DIR = Path(__file__).resolve().parents[1] / "data" / "uploads"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    # The user's message, forwarded verbatim to the session's HealthAgent.
    message: str
    # Conversation identifier; each distinct id gets its own agent instance.
    session_id: str = "default"
class IngestTextRequest(BaseModel):
    """Request body for POST /ingest/text.

    Optional fields fall back to pipeline defaults inside
    ``IngestionOptions.from_dict`` (behavior defined in app.ingestion.pipeline).
    """

    # Raw text to chunk, embed, and store.
    text: str
    # Provenance label for the text (e.g. a URL or document name) — TODO confirm.
    source: str | None = None
    # Disease identifier attached to chunk metadata.
    disease_id: str | None = None
    # Topic tag attached to chunk metadata.
    topic: str = "general"
    # Logical document id; presumably groups chunks of one document — verify in pipeline.
    document_id: str | None = None
    # Chunking overrides; None defers to pipeline defaults.
    chunk_size: int | None = None
    chunk_overlap: int | None = None
    # Embedding/upsert batch size override.
    batch_size: int | None = None
    # Scrape date in "YYYY-MM-DD" form (per /ingest/schema).
    scraped_at: str | None = None
    # True → queue a Celery task; False → ingest synchronously in-request.
    async_process: bool = True
def get_agent(session_id: str) -> HealthAgent:
    """Return the HealthAgent bound to *session_id*, creating it on first use."""
    try:
        return SESSION_AGENTS[session_id]
    except KeyError:
        # First message for this session: lazily build and cache a fresh agent.
        agent = HealthAgent(model=DEFAULT_MODEL)
        SESSION_AGENTS[session_id] = agent
        return agent
@app.on_event("startup")
async def startup_event() -> None:
    """Warm up heavyweight models at process start so the first request is fast.

    NOTE(review): ``@app.on_event`` is deprecated in recent FastAPI releases in
    favor of lifespan handlers — kept as-is to preserve the existing interface.
    """
    # Constructing NLPService presumably loads its models as a side effect — confirm.
    NLPService()
    embedder = KBEmbeddingService()
    embedder.load_model()
    print("All models loaded. Ready.")
@app.post("/chat")
async def chat(request: ChatRequest) -> dict:
    """Forward the user's message to the per-session HealthAgent and return its reply."""
    return get_agent(request.session_id).chat(request.message)
@app.get("/health")
async def health() -> dict:
    """Liveness probe reporting the vector backend in use and its stored count."""
    count = safe_count()
    # "chroma_count" duplicates "vector_count" — presumably a legacy key kept
    # for backwards compatibility; confirm before removing.
    return dict(
        status="ok",
        backend=get_vector_backend(),
        vector_count=count,
        chroma_count=count,
    )
@app.post("/reset/{session_id}")
async def reset(session_id: str) -> dict:
    """Reset the conversation state of the given session's agent."""
    # Note: this creates an agent if the session didn't exist, then resets it.
    get_agent(session_id).reset()
    return {"status": "reset", "session_id": session_id}
@app.get("/validate")
async def validate(text1: str, text2: str) -> dict:
    """Return the KB similarity score between two query-string texts."""
    return {
        "text1": text1,
        "text2": text2,
        "similarity": validate_similarity(text1, text2),
    }
@app.get("/ingest/schema")
async def ingest_schema() -> dict:
    """Describe the ingestion contract: embedding model, metadata shape, RAG knobs."""
    # Per-chunk metadata fields expected by the ingestion pipeline.
    metadata_fields = {
        "disease_id": "string",
        "topic": "string",
        "source": "string",
        "document_id": "string",
        "chunk_index": "integer",
        "scraped_at": "YYYY-MM-DD",
    }
    # Effective RAG parameters, resolved from the environment with defaults.
    rag_parameters = {
        "chunk_size": int(os.getenv("ALIVEAI_CHUNK_SIZE", "700")),
        "chunk_overlap": int(os.getenv("ALIVEAI_CHUNK_OVERLAP", "150")),
        "top_k": int(os.getenv("ALIVEAI_RAG_TOP_K", "5")),
        "top_p": float(os.getenv("ALIVEAI_LLM_TOP_P", "0.9")),
        "llm_top_k": int(os.getenv("ALIVEAI_LLM_TOP_K", "40")),
    }
    return {
        "embedding_model": "BAAI/bge-base-en-v1.5",
        "embedding_dimension": KBEmbeddingService().embedding_dimension(),
        "vector_backend": get_vector_backend(),
        "ingestion_backend": "chroma",
        "data_format": {
            "id": "string",
            "content": "string",
            "metadata": metadata_fields,
        },
        "rag_parameters": rag_parameters,
    }
@app.post("/ingest/text")
async def ingest_text_endpoint(payload: IngestTextRequest) -> dict:
    """Ingest raw text into the vector store, queued via Celery or inline.

    Returns ``{"status": "queued", "task_id": ...}`` when ``async_process`` is
    true, otherwise the direct result of :func:`ingest_text`.

    Raises:
        HTTPException: 400 if the text payload is empty or whitespace-only.
    """
    # Guard against queuing a pointless task (or chunking nothing) for empty input.
    if not payload.text or not payload.text.strip():
        raise HTTPException(status_code=400, detail="Text payload must not be empty")
    options = IngestionOptions.from_dict(
        {
            "source": payload.source,
            "disease_id": payload.disease_id,
            "topic": payload.topic,
            "document_id": payload.document_id,
            "chunk_size": payload.chunk_size,
            "chunk_overlap": payload.chunk_overlap,
            "batch_size": payload.batch_size,
            "scraped_at": payload.scraped_at,
            "vector_backend": "chroma",
        }
    )
    if payload.async_process:
        # Celery needs a serializable payload, hence options.__dict__ not options.
        task = ingest_text_task.delay(payload.text, options.__dict__)
        return {
            "status": "queued",
            "task_id": task.id,
        }
    return ingest_text(payload.text, options)
@app.post("/ingest/file")
async def ingest_file_endpoint(
    file: UploadFile = File(...),
    source: str | None = Form(default=None),
    disease_id: str | None = Form(default=None),
    topic: str = Form(default="general"),
    document_id: str | None = Form(default=None),
    chunk_size: int | None = Form(default=None),
    chunk_overlap: int | None = Form(default=None),
    batch_size: int | None = Form(default=None),
    scraped_at: str | None = Form(default=None),
    async_process: bool = Form(default=True),
) -> dict:
    """Persist an uploaded document and ingest it, queued via Celery or inline.

    The upload is stored under UPLOAD_DIR with a random name (its original
    extension preserved) so the Celery worker can read it independently of the
    request lifecycle.

    Raises:
        HTTPException: 400 for unsupported extensions or an empty upload.
    """
    suffix = Path(file.filename or "").suffix.lower()
    if suffix not in {".txt", ".pdf", ".doc", ".docx"}:
        raise HTTPException(status_code=400, detail="Supported file formats: .txt, .pdf, .doc, .docx")
    raw_content = await file.read()
    # An empty upload would otherwise be persisted and queued as a zero-byte file.
    if not raw_content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    # Random stored name avoids collisions and path traversal via the client filename.
    stored_name = f"{uuid4().hex}{suffix}"
    stored_path = UPLOAD_DIR / stored_name
    stored_path.write_bytes(raw_content)
    options = IngestionOptions.from_dict(
        {
            "source": source or file.filename,
            "disease_id": disease_id,
            "topic": topic,
            "document_id": document_id,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "batch_size": batch_size,
            "scraped_at": scraped_at,
            "vector_backend": "chroma",
        }
    )
    if async_process:
        # The worker re-opens the file by path, so only the path is serialized.
        task = ingest_file_task.delay(str(stored_path), options.__dict__)
        return {
            "status": "queued",
            "task_id": task.id,
            "file_path": str(stored_path),
        }
    return ingest_file(stored_path, options)
@app.get("/ingest/task/{task_id}")
async def ingestion_task_status(task_id: str) -> dict:
    """Report the state of a queued ingestion task, plus its result or error."""
    outcome = celery_app.AsyncResult(task_id)
    response: dict = {"task_id": task_id, "status": outcome.status}
    # Attach either the finished result or a stringified failure, never both.
    if outcome.successful():
        response["result"] = outcome.result
    elif outcome.failed():
        response["error"] = str(outcome.result)
    return response

Xet Storage Details

Size:
6.34 kB
·
Xet hash:
556acf8f04dc76846cb323e2563aaca25f40f03c93a3ef9d51bd1034581d59a9

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.